diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a1fcf5e..2c361ef 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,6 +25,7 @@ jobs: with: version: latest test: + environment: Unit Test runs-on: 'ubuntu-latest' steps: - name: Checkout @@ -35,5 +36,10 @@ jobs: go-version: '1.17' check-latest: true - name: Test + env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} run: | go test ./... diff --git a/.gitignore b/.gitignore index ff53c67..9e9437e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ xpmt .DS_Store cmd/xpmt/bin/ +pkg/experiment/local/.env diff --git a/README.md b/README.md index 300c716..ccf1fcf 100644 --- a/README.md +++ b/README.md @@ -110,3 +110,12 @@ Fetch variants for a user given an experiment user JSON object ``` > Note: must use single quotes around JSON object string + +### Running unit tests suite +To set up for running test on local, create a `.env` file in `pkg/experiment/local` with following +contents, and replace `{API_KEY}` and `{SECRET_KEY}` (or `{EU_API_KEY}` and `{EU_SECRET_KEY}` for EU data center) for the project in test: + +``` +API_KEY={API_KEY} +SECRET_KEY={SECRET_KEY} +``` diff --git a/go.mod b/go.mod index d282c4d..bad5092 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,9 @@ go 1.12 require github.com/spaolacci/murmur3 v1.1.0 -require github.com/amplitude/analytics-go v1.0.1 +require ( + github.com/amplitude/analytics-go v1.0.1 + github.com/jarcoal/httpmock v1.3.1 + github.com/joho/godotenv v1.5.1 + github.com/stretchr/testify v1.9.0 +) diff --git a/go.sum b/go.sum index 40e774e..ec3c7e2 100644 --- a/go.sum +++ b/go.sum @@ -5,18 +5,27 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/evaluation/context.go b/internal/evaluation/context.go index 62d4147..002c2c1 100644 --- a/internal/evaluation/context.go +++ b/internal/evaluation/context.go @@ -1,6 +1,8 @@ package evaluation -import "github.com/amplitude/experiment-go-server/pkg/experiment" +import ( + "github.com/amplitude/experiment-go-server/pkg/experiment" +) func UserToContext(user *experiment.User) map[string]interface{} { if user == nil { @@ -8,6 +10,7 @@ func UserToContext(user *experiment.User) map[string]interface{} { } context := make(map[string]interface{}) userMap := make(map[string]interface{}) + if len(user.UserId) != 0 { userMap["user_id"] = user.UserId } @@ -56,6 +59,58 @@ func UserToContext(user *experiment.User) map[string]interface{} { if len(user.UserProperties) != 0 { userMap["user_properties"] = user.UserProperties } + if len(user.Groups) != 0 { + userMap["groups"] = user.Groups + } + if len(user.CohortIds) != 0 { + userMap["cohort_ids"] = extractKeys(user.CohortIds) + } + context["user"] = userMap + + if user.Groups == nil { + return context + } + + groups := make(map[string]interface{}) + for groupType, groupNames := range user.Groups { + if len(groupNames) > 0 { + groupName := groupNames[0] + groupNameMap := map[string]interface{}{ + "group_name": groupName, + } + + if user.GroupProperties != nil { + if groupPropertiesType, ok := user.GroupProperties[groupType]; ok { + if groupPropertiesName, ok := groupPropertiesType[groupName]; ok { + groupNameMap["group_properties"] = groupPropertiesName + } + } + } + + if user.GroupCohortIds != nil { + if groupCohortIdsType, ok := user.GroupCohortIds[groupType]; ok { + if groupCohortIdsName, ok := groupCohortIdsType[groupName]; ok { + groupNameMap["cohort_ids"] = extractKeys(groupCohortIdsName) + } + } + } + + groups[groupType] = groupNameMap + } + } + + if len(groups) > 0 { + context["groups"] = groups + } + return context } + +func extractKeys(m map[string]struct{}) []string { + keys := make([]string, 0, len(m)) + for key := range m { + keys = append(keys, key) + } + return keys +} diff --git a/pkg/experiment/local/client.go b/pkg/experiment/local/client.go index 7ba17e1..07286a7 100644 --- a/pkg/experiment/local/client.go +++ b/pkg/experiment/local/client.go @@ -27,10 +27,13 @@ type Client struct { config *Config client *http.Client poller *poller - flags map[string]*evaluation.Flag flagsMutex *sync.RWMutex engine *evaluation.Engine assignmentService *assignmentService + cohortStorage cohortStorage + flagConfigStorage flagConfigStorage + cohortLoader *cohortLoader + deploymentRunner *deploymentRunner } func Initialize(apiKey string, config *Config) *Client { @@ -43,23 +46,35 @@ func Initialize(apiKey string, config *Config) *Client { config = fillConfigDefaults(config) log := logger.New(config.Debug) var as *assignmentService - if config.AssignmentConfig != nil && config.AssignmentConfig.APIKey != "" { + if config.AssignmentConfig != nil && config.AssignmentConfig.APIKey != "" { amplitudeClient := amplitude.NewClient(config.AssignmentConfig.Config) as = &assignmentService{ amplitude: &litudeClient, - filter: newAssignmentFilter(config.AssignmentConfig.CacheCapacity), + filter: newAssignmentFilter(config.AssignmentConfig.CacheCapacity), } } + cohortStorage := newInMemoryCohortStorage() + flagConfigStorage := newInMemoryFlagConfigStorage() + var cohortLoader *cohortLoader + var deploymentRunner *deploymentRunner + if config.CohortSyncConfig != nil { + cohortDownloadApi := newDirectCohortDownloadApi(config.CohortSyncConfig.ApiKey, config.CohortSyncConfig.SecretKey, config.CohortSyncConfig.MaxCohortSize, config.CohortSyncConfig.CohortServerUrl, config.Debug) + cohortLoader = newCohortLoader(cohortDownloadApi, cohortStorage) + } + deploymentRunner = newDeploymentRunner(config, newFlagConfigApiV2(apiKey, config.ServerUrl, config.FlagConfigPollerRequestTimeout), flagConfigStorage, cohortStorage, cohortLoader) client = &Client{ - log: log, - apiKey: apiKey, - config: config, - client: &http.Client{}, - poller: newPoller(), - flags: make(map[string]*evaluation.Flag), - flagsMutex: &sync.RWMutex{}, - engine: evaluation.NewEngine(log), + log: log, + apiKey: apiKey, + config: config, + client: &http.Client{}, + poller: newPoller(), + flagsMutex: &sync.RWMutex{}, + engine: evaluation.NewEngine(log), assignmentService: as, + cohortStorage: cohortStorage, + flagConfigStorage: flagConfigStorage, + cohortLoader: cohortLoader, + deploymentRunner: deploymentRunner, } client.log.Debug("config: %v", *config) clients[apiKey] = client @@ -69,20 +84,10 @@ func Initialize(apiKey string, config *Config) *Client { } func (c *Client) Start() error { - result, err := c.doFlagsV2() + err := c.deploymentRunner.start() if err != nil { return err } - c.flags = result - c.poller.Poll(c.config.FlagConfigPollerInterval, func() { - result, err := c.doFlagsV2() - if err != nil { - return - } - c.flagsMutex.Lock() - c.flags = result - c.flagsMutex.Unlock() - }) return nil } @@ -110,10 +115,17 @@ func (c *Client) Evaluate(user *experiment.User, flagKeys []string) (map[string] } func (c *Client) EvaluateV2(user *experiment.User, flagKeys []string) (map[string]experiment.Variant, error) { - userContext := evaluation.UserToContext(user) - c.flagsMutex.RLock() - sortedFlags, err := topologicalSort(c.flags, flagKeys) - c.flagsMutex.RUnlock() + flagConfigs := c.flagConfigStorage.getFlagConfigs() + sortedFlags, err := topologicalSort(flagConfigs, flagKeys) + if err != nil { + return nil, err + } + c.requiredCohortsInStorage(sortedFlags) + enrichedUser, err := c.enrichUserWithCohorts(user, flagConfigs) + if err != nil { + return nil, err + } + userContext := evaluation.UserToContext(enrichedUser) if err != nil { return nil, err } @@ -149,9 +161,7 @@ func (c *Client) FlagsV2() (string, error) { // FlagMetadata returns a copy of the flag's metadata. If the flag is not found then nil is returned. func (c *Client) FlagMetadata(flagKey string) map[string]interface{} { - c.flagsMutex.RLock() - f := c.flags[flagKey] - c.flagsMutex.RUnlock() + f := c.flagConfigStorage.getFlagConfig(flagKey) if f == nil { return nil } @@ -329,3 +339,56 @@ func coerceString(value interface{}) string { } return fmt.Sprintf("%v", value) } + +func (c *Client) requiredCohortsInStorage(flagConfigs []*evaluation.Flag) { + storedCohortIDs := c.cohortStorage.getCohortIds() + for _, flag := range flagConfigs { + flagCohortIDs := getAllCohortIDsFromFlag(flag) + missingCohorts := difference(flagCohortIDs, storedCohortIDs) + + if len(missingCohorts) > 0 { + if c.config.CohortSyncConfig != nil { + c.log.Debug( + "Evaluating flag %s dependent on cohorts %v without %v in storage", + flag.Key, flagCohortIDs, missingCohorts, + ) + } else { + c.log.Debug( + "Evaluating flag %s dependent on cohorts %v without cohort syncing configured", + flag.Key, flagCohortIDs, + ) + } + } + } +} + +func (c *Client) enrichUserWithCohorts(user *experiment.User, flagConfigs map[string]*evaluation.Flag) (*experiment.User, error) { + flagConfigSlice := make([]*evaluation.Flag, 0, len(flagConfigs)) + + for _, value := range flagConfigs { + flagConfigSlice = append(flagConfigSlice, value) + } + groupedCohortIDs := getGroupedCohortIDsFromFlags(flagConfigSlice) + + if cohortIDs, ok := groupedCohortIDs[userGroupType]; ok { + if len(cohortIDs) > 0 && user.UserId != "" { + user.CohortIds = c.cohortStorage.getCohortsForUser(user.UserId, cohortIDs) + } + } + + if user.Groups != nil { + for groupType, groupNames := range user.Groups { + groupName := "" + if len(groupNames) > 0 { + groupName = groupNames[0] + } + if groupName == "" { + continue + } + if cohortIDs, ok := groupedCohortIDs[groupType]; ok { + user.AddGroupCohortIds(groupType, groupName, c.cohortStorage.getCohortsForGroup(groupType, groupName, cohortIDs)) + } + } + } + return user, nil +} diff --git a/pkg/experiment/local/client_eu_test.go b/pkg/experiment/local/client_eu_test.go new file mode 100644 index 0000000..0bca7c2 --- /dev/null +++ b/pkg/experiment/local/client_eu_test.go @@ -0,0 +1,55 @@ +package local + +import ( + "github.com/amplitude/experiment-go-server/pkg/experiment" + "github.com/joho/godotenv" + "log" + "os" + "testing" +) + +var clientEU *Client + +func init() { + err := godotenv.Load() + if err != nil { + log.Printf("Error loading .env file: %v", err) + } + projectApiKey := os.Getenv("EU_API_KEY") + secretKey := os.Getenv("EU_SECRET_KEY") + cohortSyncConfig := CohortSyncConfig{ + ApiKey: projectApiKey, + SecretKey: secretKey, + } + clientEU = Initialize("server-Qlp7XiSu6JtP2S3JzA95PnP27duZgQCF", + &Config{CohortSyncConfig: &cohortSyncConfig, ServerZone: EUServerZone}) + err = clientEU.Start() + if err != nil { + panic(err) + } +} + +func TestEvaluateV2CohortEU(t *testing.T) { + targetedUser := &experiment.User{UserId: "1", DeviceId: "0"} + nonTargetedUser := &experiment.User{UserId: "not_targeted", DeviceId: "0"} + flagKeys := []string{"sdk-local-evaluation-user-cohort"} + result, err := clientEU.EvaluateV2(targetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-user-cohort"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + result, err = clientEU.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-user-cohort"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } +} diff --git a/pkg/experiment/local/client_test.go b/pkg/experiment/local/client_test.go index be7cc67..bada72a 100644 --- a/pkg/experiment/local/client_test.go +++ b/pkg/experiment/local/client_test.go @@ -2,14 +2,28 @@ package local import ( "github.com/amplitude/experiment-go-server/pkg/experiment" + "github.com/joho/godotenv" + "log" + "os" "testing" ) var client *Client func init() { - client = Initialize("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", nil) - err := client.Start() + err := godotenv.Load() + if err != nil { + log.Printf("Error loading .env file: %v", err) + } + projectApiKey := os.Getenv("API_KEY") + secretKey := os.Getenv("SECRET_KEY") + cohortSyncConfig := CohortSyncConfig{ + ApiKey: projectApiKey, + SecretKey: secretKey, + } + client = Initialize("server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz", + &Config{CohortSyncConfig: &cohortSyncConfig}) + err = client.Start() if err != nil { panic(err) } @@ -52,7 +66,6 @@ func TestEvaluate(t *testing.T) { } } - func TestEvaluateV2AllFlags(t *testing.T) { user := &experiment.User{UserId: "test_user"} result, err := client.EvaluateV2(user, nil) @@ -157,3 +170,63 @@ func TestFlagMetadataLocalFlagKey(t *testing.T) { t.Fatalf("Unexpected metadata %v", md) } } + +func TestEvaluateV2Cohort(t *testing.T) { + targetedUser := &experiment.User{UserId: "12345"} + nonTargetedUser := &experiment.User{UserId: "not_targeted"} + flagKeys := []string{"sdk-local-evaluation-user-cohort-ci-test"} + result, err := client.EvaluateV2(targetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-user-cohort-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + result, err = client.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-user-cohort-ci-test"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } +} + +func TestEvaluateV2GroupCohort(t *testing.T) { + targetedUser := &experiment.User{ + UserId: "12345", + DeviceId: "device_id", + Groups: map[string][]string{ + "org id": {"1"}, + }} + nonTargetedUser := &experiment.User{ + UserId: "12345", + DeviceId: "device_id", + Groups: map[string][]string{ + "org id": {"not_targeted"}, + }} + flagKeys := []string{"sdk-local-evaluation-group-cohort-ci-test"} + result, err := client.EvaluateV2(targetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant := result["sdk-local-evaluation-group-cohort-ci-test"] + if variant.Key != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + if variant.Value != "on" { + t.Fatalf("Unexpected variant %v", variant) + } + result, err = client.EvaluateV2(nonTargetedUser, flagKeys) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + variant = result["sdk-local-evaluation-group-cohort-ci-test"] + if variant.Key != "off" { + t.Fatalf("Unexpected variant %v", variant) + } +} diff --git a/pkg/experiment/local/cohort.go b/pkg/experiment/local/cohort.go new file mode 100644 index 0000000..c94660b --- /dev/null +++ b/pkg/experiment/local/cohort.go @@ -0,0 +1,34 @@ +package local + +import "sort" + +const userGroupType = "User" + +type Cohort struct { + Id string + LastModified int64 + Size int + MemberIds []string + GroupType string +} + +func CohortEquals(c1, c2 *Cohort) bool { + if c1.Id != c2.Id || c1.LastModified != c2.LastModified || c1.Size != c2.Size || c1.GroupType != c2.GroupType { + return false + } + if len(c1.MemberIds) != len(c2.MemberIds) { + return false + } + + // Sort MemberIds before comparing + sort.Strings(c1.MemberIds) + sort.Strings(c2.MemberIds) + + for i := range c1.MemberIds { + if c1.MemberIds[i] != c2.MemberIds[i] { + return false + } + } + + return true +} diff --git a/pkg/experiment/local/cohort_download_api.go b/pkg/experiment/local/cohort_download_api.go new file mode 100644 index 0000000..2d5ab8c --- /dev/null +++ b/pkg/experiment/local/cohort_download_api.go @@ -0,0 +1,121 @@ +package local + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "github.com/amplitude/experiment-go-server/internal/logger" + "github.com/amplitude/experiment-go-server/pkg/experiment" + "net/http" + "strconv" + "time" +) + +const cohortRequestDelay = 100 * time.Millisecond + +type cohortDownloadApi interface { + getCohort(cohortID string, cohort *Cohort) (*Cohort, error) +} + +type directCohortDownloadApi struct { + ApiKey string + SecretKey string + MaxCohortSize int + ServerUrl string + Debug bool + log *logger.Log +} + +func newDirectCohortDownloadApi(apiKey, secretKey string, maxCohortSize int, serverUrl string, debug bool) *directCohortDownloadApi { + api := &directCohortDownloadApi{ + ApiKey: apiKey, + SecretKey: secretKey, + MaxCohortSize: maxCohortSize, + ServerUrl: serverUrl, + Debug: debug, + log: logger.New(debug), + } + return api +} + +func (api *directCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) (*Cohort, error) { + api.log.Debug("getCohortMembers(%s): start", cohortID) + errors := 0 + client := &http.Client{} + + for { + response, err := api.getCohortMembersRequest(client, cohortID, cohort) + if err != nil { + api.log.Error("getCohortMembers(%s): request-status error %d - %v", cohortID, errors, err) + errors++ + if errors >= 3 || func(err error) bool { + switch err.(type) { + case *cohortTooLargeException: + return true + default: + return false + } + }(err) { + return nil, err + } + time.Sleep(cohortRequestDelay) + continue + } + + if response.StatusCode == http.StatusOK { + var cohortInfo struct { + Id string `json:"cohortId"` + LastModified int64 `json:"lastModified"` + Size int `json:"size"` + MemberIds []string `json:"memberIds"` + GroupType string `json:"groupType"` + } + if err := json.NewDecoder(response.Body).Decode(&cohortInfo); err != nil { + return nil, err + } + api.log.Debug("getCohortMembers(%s): end - resultSize=%d", cohortID, cohortInfo.Size) + return &Cohort{ + Id: cohortInfo.Id, + LastModified: cohortInfo.LastModified, + Size: cohortInfo.Size, + MemberIds: cohortInfo.MemberIds, + GroupType: func() string { + if cohortInfo.GroupType == "" { + return userGroupType + } + return cohortInfo.GroupType + }(), + }, nil + } else if response.StatusCode == http.StatusNoContent { + api.log.Debug("getCohortMembers(%s): Cohort not modified", cohortID) + return nil, nil + } else if response.StatusCode == http.StatusRequestEntityTooLarge { + return nil, &cohortTooLargeException{Message: "Cohort exceeds max cohort size of " + strconv.Itoa(api.MaxCohortSize)} + } else { + return nil, &httpErrorResponseException{StatusCode: response.StatusCode, Message: "Unexpected response code"} + } + } +} + +func (api *directCohortDownloadApi) getCohortMembersRequest(client *http.Client, cohortID string, cohort *Cohort) (*http.Response, error) { + req, err := http.NewRequest("GET", api.buildCohortURL(cohortID, cohort), nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Basic "+api.getBasicAuth()) + req.Header.Set("X-Amp-Exp-Library", fmt.Sprintf("experiment-go-server/%v", experiment.VERSION)) + return client.Do(req) +} + +func (api *directCohortDownloadApi) getBasicAuth() string { + auth := api.ApiKey + ":" + api.SecretKey + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func (api *directCohortDownloadApi) buildCohortURL(cohortID string, cohort *Cohort) string { + url := api.ServerUrl + "/sdk/v1/cohort/" + cohortID + "?maxCohortSize=" + strconv.Itoa(api.MaxCohortSize) + if cohort != nil { + url += "&lastModified=" + strconv.FormatInt(cohort.LastModified, 10) + } + return url +} diff --git a/pkg/experiment/local/cohort_download_api_test.go b/pkg/experiment/local/cohort_download_api_test.go new file mode 100644 index 0000000..7ddc870 --- /dev/null +++ b/pkg/experiment/local/cohort_download_api_test.go @@ -0,0 +1,221 @@ +package local + +import ( + "net/http" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockCohortDownloadApi struct { + mock.Mock +} + +type cohortInfo struct { + Id string `json:"cohortId"` + LastModified int64 `json:"lastModified"` + Size int `json:"size"` + MemberIds []string `json:"memberIds"` + GroupType string `json:"groupType"` +} + +func (m *MockCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) (*Cohort, error) { + args := m.Called(cohortID, cohort) + if args.Get(0) != nil { + return args.Get(0).(*Cohort), args.Error(1) + } + return nil, args.Error(1) +} + +func TestCohortDownloadApi(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + api := newDirectCohortDownloadApi("api", "secret", 15000, "https://server.amplitude.com", false) + + t.Run("test_cohort_download_success", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.getCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) + }) + + t.Run("test_cohort_download_many_202s_success", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + + for i := 0; i < 9; i++ { + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(202, ""), + ) + } + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.getCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) + }) + + t.Run("test_cohort_request_status_with_two_failures_succeeds", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(503, ""), + ) + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(503, ""), + ) + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.getCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) + }) + + t.Run("test_cohort_request_status_429s_keep_retrying", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"user"}, GroupType: "userGroupType"} + + for i := 0; i < 9; i++ { + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(429, ""), + ) + } + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.getCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) + }) + + t.Run("test_group_cohort_download_success", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.getCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) + }) + + t.Run("test_group_cohort_request_status_429s_keep_retrying", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} + response := &cohortInfo{Id: "1234", LastModified: 0, Size: 1, MemberIds: []string{"group"}, GroupType: "org name"} + + for i := 0; i < 9; i++ { + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(429, ""), + ) + } + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + func(req *http.Request) (*http.Response, error) { + resp, err := httpmock.NewJsonResponse(200, response) + if err != nil { + return httpmock.NewStringResponse(500, ""), nil + } + return resp, nil + }, + ) + + resultCohort, err := api.getCohort("1234", cohort) + assert.NoError(t, err) + assert.Equal(t, cohort.Id, resultCohort.Id) + assert.Equal(t, cohort.LastModified, resultCohort.LastModified) + assert.Equal(t, cohort.Size, resultCohort.Size) + assert.Equal(t, cohort.MemberIds, resultCohort.MemberIds) + assert.Equal(t, cohort.GroupType, resultCohort.GroupType) + }) + + t.Run("test_cohort_size_too_large", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 0, Size: 16000, MemberIds: []string{}} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(413, ""), + ) + + _, err := api.getCohort("1234", cohort) + assert.Error(t, err) + _, isCohortTooLargeException := err.(*cohortTooLargeException) + assert.True(t, isCohortTooLargeException) + }) + + t.Run("test_cohort_not_modified", func(t *testing.T) { + cohort := &Cohort{Id: "1234", LastModified: 1000, Size: 1, MemberIds: []string{}} + + httpmock.RegisterResponder("GET", api.buildCohortURL("1234", cohort), + httpmock.NewStringResponder(204, ""), + ) + + result, err := api.getCohort("1234", cohort) + assert.Nil(t, result) + assert.NoError(t, err) + }) +} diff --git a/pkg/experiment/local/cohort_loader.go b/pkg/experiment/local/cohort_loader.go new file mode 100644 index 0000000..d325315 --- /dev/null +++ b/pkg/experiment/local/cohort_loader.go @@ -0,0 +1,88 @@ +package local + +import ( + "sync" + "sync/atomic" +) + +type cohortLoader struct { + cohortDownloadApi cohortDownloadApi + cohortStorage cohortStorage + jobs sync.Map + executor *sync.Pool + lockJobs sync.Mutex +} + +func newCohortLoader(cohortDownloadApi cohortDownloadApi, cohortStorage cohortStorage) *cohortLoader { + return &cohortLoader{ + cohortDownloadApi: cohortDownloadApi, + cohortStorage: cohortStorage, + executor: &sync.Pool{ + New: func() interface{} { + return &CohortLoaderTask{} + }, + }, + } +} + +func (cl *cohortLoader) loadCohort(cohortId string) *CohortLoaderTask { + cl.lockJobs.Lock() + defer cl.lockJobs.Unlock() + + task, ok := cl.jobs.Load(cohortId) + if !ok { + task = cl.executor.Get().(*CohortLoaderTask) + task.(*CohortLoaderTask).init(cl, cohortId) + cl.jobs.Store(cohortId, task) + go task.(*CohortLoaderTask).run() + } + + return task.(*CohortLoaderTask) +} + +func (cl *cohortLoader) removeJob(cohortId string) { + cl.jobs.Delete(cohortId) +} + +type CohortLoaderTask struct { + loader *cohortLoader + cohortId string + done int32 + doneChan chan struct{} + err error +} + +func (task *CohortLoaderTask) init(loader *cohortLoader, cohortId string) { + task.loader = loader + task.cohortId = cohortId + task.done = 0 + task.doneChan = make(chan struct{}) + task.err = nil +} + +func (task *CohortLoaderTask) run() { + defer task.loader.executor.Put(task) + + cohort, err := task.loader.downloadCohort(task.cohortId) + if err != nil { + task.err = err + } else { + if cohort != nil { + task.loader.cohortStorage.putCohort(cohort) + } + } + + task.loader.removeJob(task.cohortId) + atomic.StoreInt32(&task.done, 1) + close(task.doneChan) +} + +func (task *CohortLoaderTask) wait() error { + <-task.doneChan + return task.err +} + +func (cl *cohortLoader) downloadCohort(cohortID string) (*Cohort, error) { + cohort := cl.cohortStorage.getCohort(cohortID) + return cl.cohortDownloadApi.getCohort(cohortID, cohort) +} diff --git a/pkg/experiment/local/cohort_loader_test.go b/pkg/experiment/local/cohort_loader_test.go new file mode 100644 index 0000000..4f483e1 --- /dev/null +++ b/pkg/experiment/local/cohort_loader_test.go @@ -0,0 +1,120 @@ +package local + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/mock" +) + +func TestLoadSuccess(t *testing.T) { + api := &MockCohortDownloadApi{} + storage := newInMemoryCohortStorage() + loader := newCohortLoader(api, storage) + + // Define mock behavior + api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) + api.On("getCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 0, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) + + futureA := loader.loadCohort("a") + futureB := loader.loadCohort("b") + + if err := futureA.wait(); err != nil { + t.Errorf("futureA.wait() returned error: %v", err) + } + if err := futureB.wait(); err != nil { + t.Errorf("futureB.wait() returned error: %v", err) + } + + storageDescriptionA := storage.getCohort("a") + storageDescriptionB := storage.getCohort("b") + expectedA := &Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType} + expectedB := &Cohort{Id: "b", LastModified: 0, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType} + + if !CohortEquals(storageDescriptionA, expectedA) { + t.Errorf("Unexpected cohort A stored: %+v", storageDescriptionA) + } + if !CohortEquals(storageDescriptionB, expectedB) { + t.Errorf("Unexpected cohort B stored: %+v", storageDescriptionB) + } + + storageUser1Cohorts := storage.getCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) + storageUser2Cohorts := storage.getCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) + if len(storageUser1Cohorts) != 2 || len(storageUser2Cohorts) != 1 { + t.Errorf("Unexpected user cohorts: User1: %+v, User2: %+v", storageUser1Cohorts, storageUser2Cohorts) + } +} + +func TestFilterCohortsAlreadyComputed(t *testing.T) { + api := &MockCohortDownloadApi{} + storage := newInMemoryCohortStorage() + loader := newCohortLoader(api, storage) + + storage.putCohort(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}}) + storage.putCohort(&Cohort{Id: "b", LastModified: 0, Size: 0, MemberIds: []string{}}) + + // Define mock behavior + api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}, GroupType: userGroupType}, nil) + api.On("getCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "b", LastModified: 1, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType}, nil) + + futureA := loader.loadCohort("a") + futureB := loader.loadCohort("b") + + if err := futureA.wait(); err != nil { + t.Errorf("futureA.wait() returned error: %v", err) + } + if err := futureB.wait(); err != nil { + t.Errorf("futureB.wait() returned error: %v", err) + } + + storageDescriptionA := storage.getCohort("a") + storageDescriptionB := storage.getCohort("b") + expectedA := &Cohort{Id: "a", LastModified: 0, Size: 0, MemberIds: []string{}, GroupType: userGroupType} + expectedB := &Cohort{Id: "b", LastModified: 1, Size: 2, MemberIds: []string{"1", "2"}, GroupType: userGroupType} + + if !CohortEquals(storageDescriptionA, expectedA) { + t.Errorf("Unexpected cohort A stored: %+v", storageDescriptionA) + } + if !CohortEquals(storageDescriptionB, expectedB) { + t.Errorf("Unexpected cohort B stored: %+v", storageDescriptionB) + } + + storageUser1Cohorts := storage.getCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}}) + storageUser2Cohorts := storage.getCohortsForUser("2", map[string]struct{}{"a": {}, "b": {}}) + if len(storageUser1Cohorts) != 1 || len(storageUser2Cohorts) != 1 { + t.Errorf("Unexpected user cohorts: User1: %+v, User2: %+v", storageUser1Cohorts, storageUser2Cohorts) + } +} + +func TestLoadDownloadFailureThrows(t *testing.T) { + api := &MockCohortDownloadApi{} + storage := newInMemoryCohortStorage() + loader := newCohortLoader(api, storage) + + // Define mock behavior + api.On("getCohort", "a", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "a", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) + api.On("getCohort", "b", mock.AnythingOfType("*local.Cohort")).Return(nil, errors.New("connection timed out")) + api.On("getCohort", "c", mock.AnythingOfType("*local.Cohort")).Return(&Cohort{Id: "c", LastModified: 0, Size: 1, MemberIds: []string{"1"}, GroupType: userGroupType}, nil) + + futureA := loader.loadCohort("a") + errB := loader.loadCohort("b").wait() + futureC := loader.loadCohort("c") + + if err := futureA.wait(); err != nil { + t.Errorf("futureA.wait() returned error: %v", err) + } + + if errB == nil || errB.Error() != "connection timed out" { + t.Errorf("futureB.wait() expected 'Connection timed out' error, got: %v", errB) + } + + if err := futureC.wait(); err != nil { + t.Errorf("futureC.wait() returned error: %v", err) + } + + expectedCohorts := map[string]struct{}{"a": {}, "c": {}} + actualCohorts := storage.getCohortsForUser("1", map[string]struct{}{"a": {}, "b": {}, "c": {}}) + if len(actualCohorts) != len(expectedCohorts) { + t.Errorf("Expected cohorts for user '1': %+v, but got: %+v", expectedCohorts, actualCohorts) + } +} diff --git a/pkg/experiment/local/cohort_storage.go b/pkg/experiment/local/cohort_storage.go new file mode 100644 index 0000000..7a81f84 --- /dev/null +++ b/pkg/experiment/local/cohort_storage.go @@ -0,0 +1,106 @@ +package local + +import ( + "sync" +) + +type cohortStorage interface { + getCohort(cohortID string) *Cohort + getCohorts() map[string]*Cohort + getCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} + getCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} + putCohort(cohort *Cohort) + deleteCohort(groupType, cohortID string) + getCohortIds() map[string]struct{} +} + +type inMemoryCohortStorage struct { + lock sync.RWMutex + groupToCohortStore map[string]map[string]struct{} + cohortStore map[string]*Cohort +} + +func newInMemoryCohortStorage() *inMemoryCohortStorage { + return &inMemoryCohortStorage{ + groupToCohortStore: make(map[string]map[string]struct{}), + cohortStore: make(map[string]*Cohort), + } +} + +func (s *inMemoryCohortStorage) getCohort(cohortID string) *Cohort { + s.lock.RLock() + defer s.lock.RUnlock() + return s.cohortStore[cohortID] +} + +func (s *inMemoryCohortStorage) getCohorts() map[string]*Cohort { + s.lock.RLock() + defer s.lock.RUnlock() + cohorts := make(map[string]*Cohort) + for id, cohort := range s.cohortStore { + cohorts[id] = cohort + } + return cohorts +} + +func (s *inMemoryCohortStorage) getCohortsForUser(userID string, cohortIDs map[string]struct{}) map[string]struct{} { + return s.getCohortsForGroup(userGroupType, userID, cohortIDs) +} + +func (s *inMemoryCohortStorage) getCohortsForGroup(groupType, groupName string, cohortIDs map[string]struct{}) map[string]struct{} { + result := make(map[string]struct{}) + s.lock.RLock() + defer s.lock.RUnlock() + + groupTypeCohorts, groupExists := s.groupToCohortStore[groupType] + if !groupExists { + return result + } + + for cohortID := range cohortIDs { + if _, exists := groupTypeCohorts[cohortID]; exists { + if cohort, found := s.cohortStore[cohortID]; found { + for _, memberID := range cohort.MemberIds { + if memberID == groupName { + result[cohortID] = struct{}{} + break + } + } + } + } + } + + return result +} + +func (s *inMemoryCohortStorage) putCohort(cohort *Cohort) { + s.lock.Lock() + defer s.lock.Unlock() + if _, exists := s.groupToCohortStore[cohort.GroupType]; !exists { + s.groupToCohortStore[cohort.GroupType] = make(map[string]struct{}) + } + s.groupToCohortStore[cohort.GroupType][cohort.Id] = struct{}{} + s.cohortStore[cohort.Id] = cohort +} + +func (s *inMemoryCohortStorage) deleteCohort(groupType, cohortID string) { + s.lock.Lock() + defer s.lock.Unlock() + if groupCohorts, exists := s.groupToCohortStore[groupType]; exists { + delete(groupCohorts, cohortID) + if len(groupCohorts) == 0 { + delete(s.groupToCohortStore, groupType) + } + } + delete(s.cohortStore, cohortID) +} + +func (s *inMemoryCohortStorage) getCohortIds() map[string]struct{} { + s.lock.RLock() + defer s.lock.RUnlock() + cohortIds := make(map[string]struct{}) + for id := range s.cohortStore { + cohortIds[id] = struct{}{} + } + return cohortIds +} diff --git a/pkg/experiment/local/config.go b/pkg/experiment/local/config.go index 9c9cb9d..c896fd7 100644 --- a/pkg/experiment/local/config.go +++ b/pkg/experiment/local/config.go @@ -2,15 +2,28 @@ package local import ( "github.com/amplitude/analytics-go/amplitude" + "math" "time" ) +const EUFlagServerUrl = "https://flag.lab.eu.amplitude.com" +const EUCohortSyncUrl = "https://cohort-v2.lab.eu.amplitude.com" + +type ServerZone int + +const ( + USServerZone ServerZone = iota + EUServerZone +) + type Config struct { Debug bool ServerUrl string + ServerZone ServerZone FlagConfigPollerInterval time.Duration FlagConfigPollerRequestTimeout time.Duration AssignmentConfig *AssignmentConfig + CohortSyncConfig *CohortSyncConfig } type AssignmentConfig struct { @@ -18,9 +31,18 @@ type AssignmentConfig struct { CacheCapacity int } +type CohortSyncConfig struct { + ApiKey string + SecretKey string + MaxCohortSize int + CohortPollingInterval time.Duration + CohortServerUrl string +} + var DefaultConfig = &Config{ Debug: false, ServerUrl: "https://api.lab.amplitude.com/", + ServerZone: USServerZone, FlagConfigPollerInterval: 30 * time.Second, FlagConfigPollerRequestTimeout: 10 * time.Second, } @@ -29,13 +51,28 @@ var DefaultAssignmentConfig = &AssignmentConfig{ CacheCapacity: 524288, } +var DefaultCohortSyncConfig = &CohortSyncConfig{ + MaxCohortSize: math.MaxInt32, + CohortPollingInterval: 60 * time.Second, + CohortServerUrl: "https://cohort-v2.lab.amplitude.com", +} + func fillConfigDefaults(c *Config) *Config { if c == nil { return DefaultConfig } + if c.ServerZone == 0 { + c.ServerZone = DefaultConfig.ServerZone + } if c.ServerUrl == "" { - c.ServerUrl = DefaultConfig.ServerUrl + switch c.ServerZone { + case USServerZone: + c.ServerUrl = DefaultConfig.ServerUrl + case EUServerZone: + c.ServerUrl = EUFlagServerUrl + } } + if c.FlagConfigPollerInterval == 0 { c.FlagConfigPollerInterval = DefaultConfig.FlagConfigPollerInterval } @@ -45,5 +82,23 @@ func fillConfigDefaults(c *Config) *Config { if c.AssignmentConfig != nil && c.AssignmentConfig.CacheCapacity == 0 { c.AssignmentConfig.CacheCapacity = DefaultAssignmentConfig.CacheCapacity } + + if c.CohortSyncConfig != nil && c.CohortSyncConfig.MaxCohortSize == 0 { + c.CohortSyncConfig.MaxCohortSize = DefaultCohortSyncConfig.MaxCohortSize + } + + if c.CohortSyncConfig != nil && (c.CohortSyncConfig.CohortPollingInterval < DefaultCohortSyncConfig.CohortPollingInterval) { + c.CohortSyncConfig.CohortPollingInterval = DefaultCohortSyncConfig.CohortPollingInterval + } + + if c.CohortSyncConfig != nil && c.CohortSyncConfig.CohortServerUrl == "" { + switch c.ServerZone { + case USServerZone: + c.CohortSyncConfig.CohortServerUrl = DefaultCohortSyncConfig.CohortServerUrl + case EUServerZone: + c.CohortSyncConfig.CohortServerUrl = EUCohortSyncUrl + } + } + return c } diff --git a/pkg/experiment/local/config_test.go b/pkg/experiment/local/config_test.go new file mode 100644 index 0000000..6c790e7 --- /dev/null +++ b/pkg/experiment/local/config_test.go @@ -0,0 +1,174 @@ +package local + +import ( + "testing" + "time" +) + +func TestFillConfigDefaults_ServerZoneAndServerUrl(t *testing.T) { + tests := []struct { + name string + input *Config + expectedZone ServerZone + expectedUrl string + }{ + { + name: "Nil config", + input: nil, + expectedZone: DefaultConfig.ServerZone, + expectedUrl: DefaultConfig.ServerUrl, + }, + { + name: "Empty ServerZone", + input: &Config{}, + expectedZone: DefaultConfig.ServerZone, + expectedUrl: DefaultConfig.ServerUrl, + }, + { + name: "ServerZone US", + input: &Config{ServerZone: USServerZone}, + expectedZone: USServerZone, + expectedUrl: DefaultConfig.ServerUrl, + }, + { + name: "ServerZone EU", + input: &Config{ServerZone: EUServerZone}, + expectedZone: EUServerZone, + expectedUrl: EUFlagServerUrl, + }, + { + name: "Uppercase ServerZone EU", + input: &Config{ServerZone: EUServerZone}, + expectedZone: EUServerZone, + expectedUrl: EUFlagServerUrl, + }, + { + name: "Custom ServerUrl", + input: &Config{ServerZone: USServerZone, ServerUrl: "https://custom.url/"}, + expectedZone: USServerZone, + expectedUrl: "https://custom.url/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fillConfigDefaults(tt.input) + if result.ServerZone != tt.expectedZone { + t.Errorf("expected ServerZone %d, got %d", tt.expectedZone, result.ServerZone) + } + if result.ServerUrl != tt.expectedUrl { + t.Errorf("expected ServerUrl %s, got %s", tt.expectedUrl, result.ServerUrl) + } + }) + } +} + +func TestFillConfigDefaults_CohortSyncConfig(t *testing.T) { + tests := []struct { + name string + input *Config + expectedUrl string + }{ + { + name: "Nil CohortSyncConfig", + input: &Config{ + ServerZone: EUServerZone, + CohortSyncConfig: nil, + }, + expectedUrl: "", + }, + { + name: "CohortSyncConfig with empty CohortServerUrl", + input: &Config{ + ServerZone: EUServerZone, + CohortSyncConfig: &CohortSyncConfig{}, + }, + expectedUrl: EUCohortSyncUrl, + }, + { + name: "CohortSyncConfig with custom CohortServerUrl", + input: &Config{ + ServerZone: USServerZone, + CohortSyncConfig: &CohortSyncConfig{ + CohortServerUrl: "https://custom-cohort.url/", + }, + }, + expectedUrl: "https://custom-cohort.url/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fillConfigDefaults(tt.input) + if tt.input.CohortSyncConfig == nil { + if result.CohortSyncConfig == nil { + return + } + if result.CohortSyncConfig.CohortServerUrl != tt.expectedUrl { + t.Errorf("expected CohortServerUrl %s, got %s", tt.expectedUrl, result.CohortSyncConfig.CohortServerUrl) + } + } else { + if result.CohortSyncConfig.CohortServerUrl != tt.expectedUrl { + t.Errorf("expected CohortServerUrl %s, got %s", tt.expectedUrl, result.CohortSyncConfig.CohortServerUrl) + } + } + }) + } +} + +func TestFillConfigDefaults_DefaultValues(t *testing.T) { + tests := []struct { + name string + input *Config + expected *Config + }{ + { + name: "Nil config", + input: nil, + expected: DefaultConfig, + }, + { + name: "Empty config", + input: &Config{}, + expected: &Config{ + ServerZone: DefaultConfig.ServerZone, + ServerUrl: DefaultConfig.ServerUrl, + FlagConfigPollerInterval: DefaultConfig.FlagConfigPollerInterval, + FlagConfigPollerRequestTimeout: DefaultConfig.FlagConfigPollerRequestTimeout, + }, + }, + { + name: "Custom values", + input: &Config{ + ServerZone: EUServerZone, + ServerUrl: "https://custom.url/", + FlagConfigPollerInterval: 60 * time.Second, + FlagConfigPollerRequestTimeout: 20 * time.Second, + }, + expected: &Config{ + ServerZone: EUServerZone, + ServerUrl: "https://custom.url/", + FlagConfigPollerInterval: 60 * time.Second, + FlagConfigPollerRequestTimeout: 20 * time.Second, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fillConfigDefaults(tt.input) + if result.ServerZone != tt.expected.ServerZone { + t.Errorf("expected ServerZone %d, got %d", tt.expected.ServerZone, result.ServerZone) + } + if result.ServerUrl != tt.expected.ServerUrl { + t.Errorf("expected ServerUrl %s, got %s", tt.expected.ServerUrl, result.ServerUrl) + } + if result.FlagConfigPollerInterval != tt.expected.FlagConfigPollerInterval { + t.Errorf("expected FlagConfigPollerInterval %v, got %v", tt.expected.FlagConfigPollerInterval, result.FlagConfigPollerInterval) + } + if result.FlagConfigPollerRequestTimeout != tt.expected.FlagConfigPollerRequestTimeout { + t.Errorf("expected FlagConfigPollerRequestTimeout %v, got %v", tt.expected.FlagConfigPollerRequestTimeout, result.FlagConfigPollerRequestTimeout) + } + }) + } +} diff --git a/pkg/experiment/local/deployment_runner.go b/pkg/experiment/local/deployment_runner.go new file mode 100644 index 0000000..2b85437 --- /dev/null +++ b/pkg/experiment/local/deployment_runner.go @@ -0,0 +1,196 @@ +package local + +import ( + "fmt" + "github.com/amplitude/experiment-go-server/internal/evaluation" + "github.com/amplitude/experiment-go-server/internal/logger" + "strings" + "sync" +) + +type deploymentRunner struct { + config *Config + flagConfigApi flagConfigApi + flagConfigStorage flagConfigStorage + cohortStorage cohortStorage + cohortLoader *cohortLoader + lock sync.Mutex + poller *poller + log *logger.Log +} + +func newDeploymentRunner( + config *Config, + flagConfigApi flagConfigApi, + flagConfigStorage flagConfigStorage, + cohortStorage cohortStorage, + cohortLoader *cohortLoader, +) *deploymentRunner { + dr := &deploymentRunner{ + config: config, + flagConfigApi: flagConfigApi, + flagConfigStorage: flagConfigStorage, + cohortStorage: cohortStorage, + cohortLoader: cohortLoader, + log: logger.New(config.Debug), + } + dr.poller = newPoller() + return dr +} + +func (dr *deploymentRunner) start() error { + dr.lock.Lock() + defer dr.lock.Unlock() + + if err := dr.updateFlagConfigs(); err != nil { + dr.log.Error("Initial updateFlagConfigs failed: %v", err) + return err + } + + dr.poller.Poll(dr.config.FlagConfigPollerInterval, func() { + if err := dr.periodicRefresh(); err != nil { + dr.log.Error("Periodic updateFlagConfigs failed: %v", err) + } + }) + + if dr.config.CohortSyncConfig != nil { + dr.poller.Poll(dr.config.CohortSyncConfig.CohortPollingInterval, func() { + dr.updateStoredCohorts() + }) + } + return nil +} + +func (dr *deploymentRunner) periodicRefresh() error { + defer func() { + if r := recover(); r != nil { + dr.log.Error("Recovered in periodicRefresh: %v", r) + } + }() + return dr.updateFlagConfigs() +} + +func (dr *deploymentRunner) updateFlagConfigs() error { + dr.log.Debug("Refreshing flag configs.") + flagConfigs, err := dr.flagConfigApi.getFlagConfigs() + if err != nil { + dr.log.Error("Failed to fetch flag configs: %v", err) + return err + } + + flagKeys := make(map[string]struct{}) + for _, flag := range flagConfigs { + flagKeys[flag.Key] = struct{}{} + } + + dr.flagConfigStorage.removeIf(func(f *evaluation.Flag) bool { + _, exists := flagKeys[f.Key] + return !exists + }) + + if dr.cohortLoader == nil { + for _, flagConfig := range flagConfigs { + dr.log.Debug("Putting non-cohort flag %s", flagConfig.Key) + dr.flagConfigStorage.putFlagConfig(flagConfig) + } + return nil + } + + newCohortIDs := make(map[string]struct{}) + for _, flagConfig := range flagConfigs { + for cohortID := range getAllCohortIDsFromFlag(flagConfig) { + newCohortIDs[cohortID] = struct{}{} + } + } + + existingCohortIDs := dr.cohortStorage.getCohortIds() + cohortIDsToDownload := difference(newCohortIDs, existingCohortIDs) + + // Download all new cohorts + dr.downloadCohorts(cohortIDsToDownload) + + // Get updated set of cohort ids + updatedCohortIDs := dr.cohortStorage.getCohortIds() + // Iterate through new flag configs and check if their required cohorts exist + for _, flagConfig := range flagConfigs { + cohortIDs := getAllCohortIDsFromFlag(flagConfig) + missingCohorts := difference(cohortIDs, updatedCohortIDs) + + dr.flagConfigStorage.putFlagConfig(flagConfig) + dr.log.Debug("Putting flag %s", flagConfig.Key) + if len(missingCohorts) != 0 { + dr.log.Error("Flag %s - failed to load cohorts: %v", flagConfig.Key, missingCohorts) + } + } + + // Delete unused cohorts + dr.deleteUnusedCohorts() + dr.log.Debug("Refreshed %d flag configs.", len(flagConfigs)) + + return nil +} + +func (dr *deploymentRunner) updateStoredCohorts() { + cohortIDs := getAllCohortIDsFromFlags(dr.flagConfigStorage.getFlagConfigsArray()) + dr.downloadCohorts(cohortIDs) +} + +func (dr *deploymentRunner) deleteUnusedCohorts() { + flagCohortIDs := make(map[string]struct{}) + for _, flag := range dr.flagConfigStorage.getFlagConfigs() { + for cohortID := range getAllCohortIDsFromFlag(flag) { + flagCohortIDs[cohortID] = struct{}{} + } + } + + storageCohorts := dr.cohortStorage.getCohorts() + for cohortID := range storageCohorts { + if _, exists := flagCohortIDs[cohortID]; !exists { + cohort := storageCohorts[cohortID] + if cohort != nil { + dr.cohortStorage.deleteCohort(cohort.GroupType, cohortID) + } + } + } +} + +func difference(set1, set2 map[string]struct{}) map[string]struct{} { + diff := make(map[string]struct{}) + for k := range set1 { + if _, exists := set2[k]; !exists { + diff[k] = struct{}{} + } + } + return diff +} + +func (dr *deploymentRunner) downloadCohorts(cohortIDs map[string]struct{}) { + var wg sync.WaitGroup + errorChan := make(chan error, len(cohortIDs)) + + for cohortID := range cohortIDs { + wg.Add(1) + go func(id string) { + defer wg.Done() + task := dr.cohortLoader.loadCohort(id) + if err := task.wait(); err != nil { + errorChan <- fmt.Errorf("cohort %s: %v", id, err) + } + }(cohortID) + } + + go func() { + wg.Wait() + close(errorChan) + }() + + var errorMessages []string + for err := range errorChan { + errorMessages = append(errorMessages, err.Error()) + dr.log.Error("Error downloading cohort: %v", err) + } + + if len(errorMessages) > 0 { + dr.log.Error("One or more cohorts failed to download:\n%s", strings.Join(errorMessages, "\n")) + } +} diff --git a/pkg/experiment/local/deployment_runner_test.go b/pkg/experiment/local/deployment_runner_test.go new file mode 100644 index 0000000..691dae7 --- /dev/null +++ b/pkg/experiment/local/deployment_runner_test.go @@ -0,0 +1,105 @@ +package local + +import ( + "errors" + "fmt" + "testing" + + "github.com/amplitude/experiment-go-server/internal/evaluation" +) + +const ( + CohortId = "1234" +) + +func TestStartThrowsIfFirstFlagConfigLoadFails(t *testing.T) { + flagAPI := &mockFlagConfigApi{getFlagConfigsFunc: func() (map[string]*evaluation.Flag, error) { + return nil, errors.New("test") + }} + cohortDownloadAPI := &mockCohortDownloadApi{} + flagConfigStorage := newInMemoryFlagConfigStorage() + cohortStorage := newInMemoryCohortStorage() + cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage) + + runner := newDeploymentRunner( + &Config{}, + flagAPI, + flagConfigStorage, + cohortStorage, + cohortLoader, + ) + + err := runner.start() + + if err == nil { + t.Error("Expected error but got nil") + } +} + +func TestStartSucceedsEvenIfFirstCohortLoadFails(t *testing.T) { + flagAPI := &mockFlagConfigApi{getFlagConfigsFunc: func() (map[string]*evaluation.Flag, error) { + return map[string]*evaluation.Flag{"flag": createTestFlag()}, nil + }} + cohortDownloadAPI := &mockCohortDownloadApi{getCohortFunc: func(cohortID string, cohort *Cohort) (*Cohort, error) { + return nil, errors.New("test") + }} + flagConfigStorage := newInMemoryFlagConfigStorage() + cohortStorage := newInMemoryCohortStorage() + cohortLoader := newCohortLoader(cohortDownloadAPI, cohortStorage) + + runner := newDeploymentRunner( + DefaultConfig, + flagAPI, + flagConfigStorage, + cohortStorage, + cohortLoader, + ) + + err := runner.start() + + if err != nil { + t.Errorf("Expected no error but got %v", err) + } +} + +type mockFlagConfigApi struct { + getFlagConfigsFunc func() (map[string]*evaluation.Flag, error) +} + +func (m *mockFlagConfigApi) getFlagConfigs() (map[string]*evaluation.Flag, error) { + if m.getFlagConfigsFunc != nil { + return m.getFlagConfigsFunc() + } + return nil, fmt.Errorf("mock not implemented") +} + +type mockCohortDownloadApi struct { + getCohortFunc func(cohortID string, cohort *Cohort) (*Cohort, error) +} + +func (m *mockCohortDownloadApi) getCohort(cohortID string, cohort *Cohort) (*Cohort, error) { + if m.getCohortFunc != nil { + return m.getCohortFunc(cohortID, cohort) + } + return nil, fmt.Errorf("mock not implemented") +} + +func createTestFlag() *evaluation.Flag { + return &evaluation.Flag{ + Key: "flag", + Variants: map[string]*evaluation.Variant{}, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Selector: []string{"context", "user", "cohort_ids"}, + Op: "set contains any", + Values: []string{CohortId}, + }, + }, + }, + }, + }, + } +} diff --git a/pkg/experiment/local/exception.go b/pkg/experiment/local/exception.go new file mode 100644 index 0000000..8e8dba6 --- /dev/null +++ b/pkg/experiment/local/exception.go @@ -0,0 +1,18 @@ +package local + +type httpErrorResponseException struct { + StatusCode int + Message string +} + +func (e *httpErrorResponseException) Error() string { + return e.Message +} + +type cohortTooLargeException struct { + Message string +} + +func (e *cohortTooLargeException) Error() string { + return e.Message +} diff --git a/pkg/experiment/local/flag_config_api.go b/pkg/experiment/local/flag_config_api.go new file mode 100644 index 0000000..b8dc324 --- /dev/null +++ b/pkg/experiment/local/flag_config_api.go @@ -0,0 +1,70 @@ +package local + +import ( + "context" + "encoding/json" + "fmt" + "github.com/amplitude/experiment-go-server/internal/evaluation" + "github.com/amplitude/experiment-go-server/pkg/experiment" + "io/ioutil" + "net/http" + "net/url" + "time" +) + +type flagConfigApi interface { + getFlagConfigs() (map[string]*evaluation.Flag, error) +} + +type flagConfigApiV2 struct { + DeploymentKey string + ServerURL string + FlagConfigPollerRequestTimeoutMillis time.Duration +} + +func newFlagConfigApiV2(deploymentKey, serverURL string, flagConfigPollerRequestTimeoutMillis time.Duration) *flagConfigApiV2 { + return &flagConfigApiV2{ + DeploymentKey: deploymentKey, + ServerURL: serverURL, + FlagConfigPollerRequestTimeoutMillis: flagConfigPollerRequestTimeoutMillis, + } +} + +func (a *flagConfigApiV2) getFlagConfigs() (map[string]*evaluation.Flag, error) { + client := &http.Client{} + endpoint, err := url.Parse(a.ServerURL) + if err != nil { + return nil, err + } + endpoint.Path = "sdk/v2/flags" + endpoint.RawQuery = "v=0" + ctx, cancel := context.WithTimeout(context.Background(), a.FlagConfigPollerRequestTimeoutMillis) + defer cancel() + req, err := http.NewRequest("GET", endpoint.String(), nil) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + req.Header.Set("Authorization", fmt.Sprintf("Api-Key %s", a.DeploymentKey)) + req.Header.Set("Content-Type", "application/json; charset=UTF-8") + req.Header.Set("X-Amp-Exp-Library", fmt.Sprintf("experiment-go-server/%v", experiment.VERSION)) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var flagsArray []*evaluation.Flag + err = json.Unmarshal(body, &flagsArray) + if err != nil { + return nil, err + } + flags := make(map[string]*evaluation.Flag) + for _, flag := range flagsArray { + flags[flag.Key] = flag + } + return flags, nil +} diff --git a/pkg/experiment/local/flag_config_storage.go b/pkg/experiment/local/flag_config_storage.go new file mode 100644 index 0000000..02daea6 --- /dev/null +++ b/pkg/experiment/local/flag_config_storage.go @@ -0,0 +1,68 @@ +package local + +import ( + "github.com/amplitude/experiment-go-server/internal/evaluation" + "sync" +) + +type flagConfigStorage interface { + getFlagConfig(key string) *evaluation.Flag + getFlagConfigs() map[string]*evaluation.Flag + getFlagConfigsArray() []*evaluation.Flag + putFlagConfig(flagConfig *evaluation.Flag) + removeIf(condition func(*evaluation.Flag) bool) +} + +type inMemoryFlagConfigStorage struct { + flagConfigs map[string]*evaluation.Flag + flagConfigsLock sync.Mutex +} + +func newInMemoryFlagConfigStorage() *inMemoryFlagConfigStorage { + return &inMemoryFlagConfigStorage{ + flagConfigs: make(map[string]*evaluation.Flag), + } +} + +func (storage *inMemoryFlagConfigStorage) getFlagConfig(key string) *evaluation.Flag { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + return storage.flagConfigs[key] +} + +func (storage *inMemoryFlagConfigStorage) getFlagConfigs() map[string]*evaluation.Flag { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + copyFlagConfigs := make(map[string]*evaluation.Flag) + for key, value := range storage.flagConfigs { + copyFlagConfigs[key] = value + } + return copyFlagConfigs +} + +func (storage *inMemoryFlagConfigStorage) getFlagConfigsArray() []*evaluation.Flag { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + + var copyFlagConfigs []*evaluation.Flag + for _, value := range storage.flagConfigs { + copyFlagConfigs = append(copyFlagConfigs, value) + } + return copyFlagConfigs +} + +func (storage *inMemoryFlagConfigStorage) putFlagConfig(flagConfig *evaluation.Flag) { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + storage.flagConfigs[flagConfig.Key] = flagConfig +} + +func (storage *inMemoryFlagConfigStorage) removeIf(condition func(*evaluation.Flag) bool) { + storage.flagConfigsLock.Lock() + defer storage.flagConfigsLock.Unlock() + for key, value := range storage.flagConfigs { + if condition(value) { + delete(storage.flagConfigs, key) + } + } +} diff --git a/pkg/experiment/local/flag_config_test.go b/pkg/experiment/local/flag_config_test.go new file mode 100644 index 0000000..686e672 --- /dev/null +++ b/pkg/experiment/local/flag_config_test.go @@ -0,0 +1,221 @@ +package local + +import ( + "testing" + + "github.com/amplitude/experiment-go-server/internal/evaluation" + "github.com/stretchr/testify/assert" +) + +func TestGetAllCohortIDsFromFlag(t *testing.T) { + flags := getTestFlags() + expectedCohortIDs := []string{ + "cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6", "cohort7", "cohort8", + } + expectedCohortIDSet := make(map[string]bool) + for _, id := range expectedCohortIDs { + expectedCohortIDSet[id] = true + } + + for _, flag := range flags { + cohortIDs := getAllCohortIDsFromFlag(flag) + for id := range cohortIDs { + assert.True(t, expectedCohortIDSet[id]) + } + } +} + +func TestGetGroupedCohortIDsFromFlag(t *testing.T) { + flags := getTestFlags() + expectedGroupedCohortIDs := map[string][]string{ + "User": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, + "group_name": {"cohort7", "cohort8"}, + } + + for _, flag := range flags { + groupedCohortIDs := getGroupedCohortIDsFromFlag(flag) + for key, values := range groupedCohortIDs { + assert.Contains(t, expectedGroupedCohortIDs, key) + expectedSet := make(map[string]bool) + for _, id := range expectedGroupedCohortIDs[key] { + expectedSet[id] = true + } + for id := range values { + assert.True(t, expectedSet[id]) + } + } + } +} + +func TestGetAllCohortIDsFromFlags(t *testing.T) { + flags := getTestFlags() + expectedCohortIDs := []string{ + "cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6", "cohort7", "cohort8", + } + expectedCohortIDSet := make(map[string]bool) + for _, id := range expectedCohortIDs { + expectedCohortIDSet[id] = true + } + + cohortIDs := getAllCohortIDsFromFlags(flags) + for id := range cohortIDs { + assert.True(t, expectedCohortIDSet[id]) + } +} + +func TestGetGroupedCohortIDsFromFlags(t *testing.T) { + flags := getTestFlags() + expectedGroupedCohortIDs := map[string][]string{ + "User": {"cohort1", "cohort2", "cohort3", "cohort4", "cohort5", "cohort6"}, + "group_name": {"cohort7", "cohort8"}, + } + + groupedCohortIDs := getGroupedCohortIDsFromFlags(flags) + for key, values := range groupedCohortIDs { + assert.Contains(t, expectedGroupedCohortIDs, key) + expectedSet := make(map[string]bool) + for _, id := range expectedGroupedCohortIDs[key] { + expectedSet[id] = true + } + for id := range values { + assert.True(t, expectedSet[id]) + } + } +} + +func getTestFlags() []*evaluation.Flag { + return []*evaluation.Flag{ + { + Key: "flag-1", + Metadata: map[string]interface{}{ + "deployed": true, + "evaluationMode": "local", + "flagType": "release", + "flagVersion": 1, + }, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Op: "set contains any", + Selector: []string{"context", "user", "cohort_ids"}, + Values: []string{"cohort1", "cohort2"}, + }, + }, + }, + Metadata: map[string]interface{}{ + "segmentName": "Segment A", + }, + Variant: "on", + }, + { + Metadata: map[string]interface{}{ + "segmentName": "All Other Users", + }, + Variant: "off", + }, + }, + Variants: map[string]*evaluation.Variant{ + "off": { + Key: "off", + Metadata: map[string]interface{}{ + "default": true, + }, + }, + "on": { + Key: "on", + Value: "on", + }, + }, + }, + { + Key: "flag-2", + Metadata: map[string]interface{}{ + "deployed": true, + "evaluationMode": "local", + "flagType": "release", + "flagVersion": 2, + }, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Op: "set contains any", + Selector: []string{"context", "user", "cohort_ids"}, + Values: []string{"cohort3", "cohort4", "cohort5", "cohort6"}, + }, + }, + }, + Metadata: map[string]interface{}{ + "segmentName": "Segment B", + }, + Variant: "on", + }, + { + Metadata: map[string]interface{}{ + "segmentName": "All Other Users", + }, + Variant: "off", + }, + }, + Variants: map[string]*evaluation.Variant{ + "off": { + Key: "off", + Metadata: map[string]interface{}{ + "default": true, + }, + }, + "on": { + Key: "on", + Value: "on", + }, + }, + }, + { + Key: "flag-3", + Metadata: map[string]interface{}{ + "deployed": true, + "evaluationMode": "local", + "flagType": "release", + "flagVersion": 3, + }, + Segments: []*evaluation.Segment{ + { + Conditions: [][]*evaluation.Condition{ + { + { + Op: "set contains any", + Selector: []string{"context", "groups", "group_name", "cohort_ids"}, + Values: []string{"cohort7", "cohort8"}, + }, + }, + }, + Metadata: map[string]interface{}{ + "segmentName": "Segment C", + }, + Variant: "on", + }, + { + Metadata: map[string]interface{}{ + "segmentName": "All Other Groups", + }, + Variant: "off", + }, + }, + Variants: map[string]*evaluation.Variant{ + "off": { + Key: "off", + Metadata: map[string]interface{}{ + "default": true, + }, + }, + "on": { + Key: "on", + Value: "on", + }, + }, + }, + } +} diff --git a/pkg/experiment/local/flag_config_util.go b/pkg/experiment/local/flag_config_util.go new file mode 100644 index 0000000..d3fc46b --- /dev/null +++ b/pkg/experiment/local/flag_config_util.go @@ -0,0 +1,106 @@ +package local + +import ( + "github.com/amplitude/experiment-go-server/internal/evaluation" +) + +func isCohortFilter(condition *evaluation.Condition) bool { + op := condition.Op + selector := condition.Selector + if len(selector) > 0 && selector[len(selector)-1] == "cohort_ids" { + return op == "set contains any" || op == "set does not contain any" + } + return false +} + +func getGroupedCohortConditionIDs(segment *evaluation.Segment) map[string]map[string]struct{} { + cohortIDs := make(map[string]map[string]struct{}) + if segment == nil { + return cohortIDs + } + + for _, outer := range segment.Conditions { + for _, condition := range outer { + if isCohortFilter(condition) { + selector := condition.Selector + if len(selector) > 2 { + contextSubtype := selector[1] + var groupType string + if contextSubtype == "user" { + groupType = userGroupType + } else if selectorContainsGroups(selector) { + groupType = selector[2] + } else { + continue + } + values := condition.Values + cohortIDs[groupType] = make(map[string]struct{}) + for _, value := range values { + cohortIDs[groupType][value] = struct{}{} + } + } + } + } + } + return cohortIDs +} + +func getGroupedCohortIDsFromFlag(flag *evaluation.Flag) map[string]map[string]struct{} { + cohortIDs := make(map[string]map[string]struct{}) + for _, segment := range flag.Segments { + for key, values := range getGroupedCohortConditionIDs(segment) { + if _, exists := cohortIDs[key]; !exists { + cohortIDs[key] = make(map[string]struct{}) + } + for id := range values { + cohortIDs[key][id] = struct{}{} + } + } + } + return cohortIDs +} + +func getAllCohortIDsFromFlag(flag *evaluation.Flag) map[string]struct{} { + cohortIDs := make(map[string]struct{}) + groupedIDs := getGroupedCohortIDsFromFlag(flag) + for _, values := range groupedIDs { + for id := range values { + cohortIDs[id] = struct{}{} + } + } + return cohortIDs +} + +func getGroupedCohortIDsFromFlags(flags []*evaluation.Flag) map[string]map[string]struct{} { + cohortIDs := make(map[string]map[string]struct{}) + for _, flag := range flags { + for key, values := range getGroupedCohortIDsFromFlag(flag) { + if _, exists := cohortIDs[key]; !exists { + cohortIDs[key] = make(map[string]struct{}) + } + for id := range values { + cohortIDs[key][id] = struct{}{} + } + } + } + return cohortIDs +} + +func getAllCohortIDsFromFlags(flags []*evaluation.Flag) map[string]struct{} { + cohortIDs := make(map[string]struct{}) + for _, flag := range flags { + for id := range getAllCohortIDsFromFlag(flag) { + cohortIDs[id] = struct{}{} + } + } + return cohortIDs +} + +func selectorContainsGroups(selector []string) bool { + for _, s := range selector { + if s == "groups" { + return true + } + } + return false +} diff --git a/pkg/experiment/types.go b/pkg/experiment/types.go index 82910b5..08ea1a1 100644 --- a/pkg/experiment/types.go +++ b/pkg/experiment/types.go @@ -3,22 +3,40 @@ package experiment const VERSION = "1.5.0" type User struct { - UserId string `json:"user_id,omitempty"` - DeviceId string `json:"device_id,omitempty"` - Country string `json:"country,omitempty"` - Region string `json:"region,omitempty"` - Dma string `json:"dma,omitempty"` - City string `json:"city,omitempty"` - Language string `json:"language,omitempty"` - Platform string `json:"platform,omitempty"` - Version string `json:"version,omitempty"` - Os string `json:"os,omitempty"` - DeviceManufacturer string `json:"device_manufacturer,omitempty"` - DeviceBrand string `json:"device_brand,omitempty"` - DeviceModel string `json:"device_model,omitempty"` - Carrier string `json:"carrier,omitempty"` - Library string `json:"library,omitempty"` - UserProperties map[string]interface{} `json:"user_properties,omitempty"` + UserId string `json:"user_id,omitempty"` + DeviceId string `json:"device_id,omitempty"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` + Dma string `json:"dma,omitempty"` + City string `json:"city,omitempty"` + Language string `json:"language,omitempty"` + Platform string `json:"platform,omitempty"` + Version string `json:"version,omitempty"` + Os string `json:"os,omitempty"` + DeviceManufacturer string `json:"device_manufacturer,omitempty"` + DeviceBrand string `json:"device_brand,omitempty"` + DeviceModel string `json:"device_model,omitempty"` + Carrier string `json:"carrier,omitempty"` + Library string `json:"library,omitempty"` + UserProperties map[string]interface{} `json:"user_properties,omitempty"` + GroupProperties map[string]map[string]string `json:"group_properties,omitempty"` + Groups map[string][]string `json:"groups,omitempty"` + CohortIds map[string]struct{} `json:"cohort_ids,omitempty"` + GroupCohortIds map[string]map[string]map[string]struct{} `json:"group_cohort_ids,omitempty"` +} + +func (u *User) AddGroupCohortIds(groupType, groupName string, cohortIds map[string]struct{}) { + if u.GroupCohortIds == nil { + u.GroupCohortIds = make(map[string]map[string]map[string]struct{}) + } + + groupNames := u.GroupCohortIds[groupType] + if groupNames == nil { + groupNames = make(map[string]map[string]struct{}) + u.GroupCohortIds[groupType] = groupNames + } + + groupNames[groupName] = cohortIds } type Variant struct {