diff --git a/.gitignore b/.gitignore index 07cba02..2ef1778 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ ecs-deploy-linux-amd64 webapp/node_modules/ webapp/dist/ vendor/ +bin/ +.vscode/ diff --git a/Gopkg.lock b/Gopkg.lock index 5f4abd5..702044d 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -26,7 +26,7 @@ version = "v2.3.1" [[projects]] - digest = "1:48934db4e5e41f30e83c5650b4e8d3fcc58a431987cdd36d9d019eecc45f2645" + digest = "1:1713df27130c2745660c1773b3101e63b0888d930089687be5b2f61b3f2fe8e9" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -38,6 +38,7 @@ "aws/credentials", "aws/credentials/ec2rolecreds", "aws/credentials/endpointcreds", + "aws/credentials/processcreds", "aws/credentials/stscreds", "aws/crr", "aws/csm", @@ -65,6 +66,7 @@ "service/autoscaling", "service/cloudwatch", "service/cloudwatchlogs", + "service/cognitoidentityprovider", "service/dynamodb", "service/dynamodb/dynamodbattribute", "service/dynamodb/dynamodbiface", @@ -78,8 +80,8 @@ "service/sts", ] pruneopts = "" - revision = "ceab7b7ac6f535d1397f124620720b5145f1ad59" - version = "v1.15.88" + revision = "ca6d5f771b63df153f02fe5b117434d5ebafeeda" + version = "v1.19.15" [[projects]] digest = "1:a2470e727142c0fb8dbf6f230fb63cf8dc39da614b55c157ad31d27fbc40e033" @@ -244,11 +246,11 @@ revision = "78f36f7876767b1899b4bac7381d109fe5a58539" [[projects]] - digest = "1:6f49eae0c1e5dab1dafafee34b207aeb7a42303105960944828c2079b92fc88e" + digest = "1:13fe471d0ed891e8544eddfeeb0471fd3c9f2015609a1c000aefdedf52a19d40" name = "github.com/jmespath/go-jmespath" packages = ["."] pruneopts = "" - revision = "0b12d6b5" + revision = "c2b33e84" [[projects]] digest = "1:302ad9379eb146668760df4d779a95379acab43ce5f9a28f27f3273f98232020" @@ -448,6 +450,7 @@ "github.com/aws/aws-sdk-go/service/autoscaling", "github.com/aws/aws-sdk-go/service/cloudwatch", "github.com/aws/aws-sdk-go/service/cloudwatchlogs", + "github.com/aws/aws-sdk-go/service/cognitoidentityprovider", "github.com/aws/aws-sdk-go/service/dynamodb", "github.com/aws/aws-sdk-go/service/ec2", "github.com/aws/aws-sdk-go/service/ecr", diff --git a/Gopkg.toml b/Gopkg.toml index c274e99..78905b1 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -27,7 +27,7 @@ [[constraint]] name = "github.com/aws/aws-sdk-go" - version = "1.15.88" + version = "1.19.15" [[constraint]] name = "github.com/crewjam/saml" diff --git a/Makefile b/Makefile index 81f9386..7181148 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,13 @@ deps: build-server: GOOS=linux GOARCH=${GOARCH} go build ${LDFLAGS} -o ${SERVER_BINARY}-linux-${GOARCH} cmd/ecs-deploy/main.go +build-server-darwin: + GOOS=darwin GOARCH=${GOARCH} go build ${LDFLAGS} -o ${SERVER_BINARY}-linux-${GOARCH} cmd/ecs-deploy/main.go + build-client: GOOS=linux GOARCH=${GOARCH} go build ${LDFLAGS} -o ${CLIENT_BINARY}-linux-${GOARCH} cmd/ecs-client/main.go +build-client-darwin: + GOOS=darwin GOARCH=${GOARCH} go build ${LDFLAGS} -o ${CLIENT_BINARY}-linux-${GOARCH} cmd/ecs-client/main.go build-server-static: CGO_ENABLED=0 GOOS=linux GOARCH=${GOARCH} go build -a -installsuffix cgo ${LDFLAGS} -o ${SERVER_BINARY}-linux-${GOARCH} cmd/ecs-deploy/main.go diff --git a/api/api_test.go b/api/api_test.go index 7ffa9f5..8668326 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -2,6 +2,8 @@ package api import ( "testing" + + "github.com/in4it/ecs-deploy/service" ) func TestDeployServiceValidator(t *testing.T) { @@ -9,8 +11,8 @@ func TestDeployServiceValidator(t *testing.T) { a := API{} // test with 2 characters - d := Deploy{ - Containers: []*DeployContainer{ + d := service.Deploy{ + Containers: []*service.DeployContainer{ { ContainerName: "abc", }, @@ -23,8 +25,8 @@ func TestDeployServiceValidator(t *testing.T) { } // test with 3 characters - d = Deploy{ - Containers: []*DeployContainer{ + d = service.Deploy{ + Containers: []*service.DeployContainer{ { ContainerName: "abc", }, @@ -38,8 +40,8 @@ func TestDeployServiceValidator(t *testing.T) { // test with wrong container name serviceName = "myservice" - d = Deploy{ - Containers: []*DeployContainer{ + d = service.Deploy{ + Containers: []*service.DeployContainer{ { ContainerName: "ab", }, diff --git a/api/controller.go b/api/controller.go index 9fa2a62..c256414 100644 --- a/api/controller.go +++ b/api/controller.go @@ -140,7 +140,10 @@ func (c *Controller) Deploy(serviceName string, d service.Deploy) (*service.Depl if err != nil { return nil, err } - c.updateDeployment(d, ddLast, serviceName, taskDefArn, iamRoleArn) + err = c.updateDeployment(d, ddLast, serviceName, taskDefArn, iamRoleArn) + if err != nil { + controllerLogger.Errorf("updateDeployment failed: %s", err) + } } // Mark previous deployment as aborted if still running @@ -272,6 +275,21 @@ func (c *Controller) updateDeployment(d service.Deploy, ddLast *service.DynamoDe s.UpdateServiceListeners(s.ClusterName, s.ServiceName, listeners) // don't update ecs service later updateECSService = false + } else { + // check for rules changes + if c.rulesChanged(d, ddLast) { + controllerLogger.Infof("Recreating alb rules for: " + serviceName) + // recreate rules + err = c.deleteRulesForTarget(serviceName, d, targetGroupArn, alb) + if err != nil { + controllerLogger.Infof("Couldn't delete existing rules for target: " + serviceName) + } + // create new rules + _, err := c.createRulesForTarget(serviceName, d, targetGroupArn, alb) + if err != nil { + return err + } + } } } ps := ecs.Paramstore{} @@ -314,6 +332,29 @@ func (c *Controller) updateDeployment(d service.Deploy, ddLast *service.DynamoDe } return nil } + +func (c *Controller) rulesChanged(d service.Deploy, ddLast *service.DynamoDeployment) bool { + if len(d.RuleConditions) != len(ddLast.DeployData.RuleConditions) { + return true + } + + // sort rule conditions + sortedRuleCondition := d.RuleConditions + ddLastSortedRuleCondition := ddLast.DeployData.RuleConditions + sort.Sort(ruleConditionSort(sortedRuleCondition)) + sort.Sort(ruleConditionSort(ddLastSortedRuleCondition)) + // loop over rule conditions to compare them + for k, v := range sortedRuleCondition { + v2 := ddLastSortedRuleCondition[k] + if !cmp.Equal(v, v2) { + return true + } + } + + return false + +} + func (c *Controller) redeploy(serviceName, time string) (*service.DeployResult, error) { s := service.NewService() dd, err := s.GetDeployment(serviceName, time) @@ -431,13 +472,78 @@ func (c *Controller) deleteRulesForTarget(serviceName string, d service.Deploy, if err != nil { return err } - ruleArns := alb.GetRulesByTargetGroupArn(*targetGroupArn) - for _, ruleArn := range ruleArns { + ruleArnsToDelete := alb.GetRulesByTargetGroupArn(*targetGroupArn) + authRuleArns := alb.GetRuleByTargetGroupArnWithAuth(*targetGroupArn) + for _, authRuleArn := range authRuleArns { + conditionField, conditionValue := alb.GetConditionsForRule(authRuleArn) + controllerLogger.Debugf("deleteRulesForTarget: found authRule with conditionField %s and conditionValue %s", strings.Join(conditionField, ","), strings.Join(conditionValue, ",")) + httpListener := alb.GetListenerArnForProtocol("http") + if httpListener != "" { + ruleArn, _, err := alb.FindRule(httpListener, "", conditionField, conditionValue) + if err != nil { + controllerLogger.Debugf("deleteRulesForTarget: rule not found: %s", err) + } + if ruleArn != nil { + ruleArnsToDelete = append(ruleArnsToDelete, *ruleArn) + } + + } + } + for _, ruleArn := range ruleArnsToDelete { alb.DeleteRule(ruleArn) } return nil } +// delete rule for a targetgroup with specific listener +func (c *Controller) deleteRuleForTargetWithListener(serviceName string, r *service.DeployRuleConditions, targetGroupArn *string, alb *ecs.ALB, listener string) error { + _, conditionField, conditionValue := c.getALBConditionFieldAndValue(*r, alb.GetDomain()) + err := alb.GetRulesForAllListeners() + if err != nil { + return err + } + ruleArn, _, err := alb.FindRule(listener, *targetGroupArn, conditionField, conditionValue) + if err != nil { + return err + } + return alb.DeleteRule(*ruleArn) +} + +// Update rule for a specific targetGroups +func (c *Controller) UpdateRuleForTarget(serviceName string, r *service.DeployRuleConditions, rLast *service.DeployRuleConditions, targetGroupArn *string, alb *ecs.ALB, listener string) error { + _, conditionField, conditionValue := c.getALBConditionFieldAndValue(*rLast, alb.GetDomain()) + err := alb.GetRulesForAllListeners() + if err != nil { + return err + } + ruleArn, _, err := alb.FindRule(alb.GetListenerArnForProtocol(listener), *targetGroupArn, conditionField, conditionValue) + if err != nil { + return err + } + ruleType, _, conditionValue := c.getALBConditionFieldAndValue(*rLast, alb.GetDomain()) + + // if cognito is set, a redirect is needed instead (cognito doesn't work with http) + if strings.ToLower(listener) == "http" && r.CognitoAuth.ClientName != "" { + return alb.UpdateRuleToHTTPSRedirect(*targetGroupArn, *ruleArn, ruleType, conditionValue) + } + + return alb.UpdateRule(*targetGroupArn, *ruleArn, ruleType, conditionValue, r.CognitoAuth) + +} + +func (c *Controller) getALBConditionFieldAndValue(r service.DeployRuleConditions, domain string) (string, []string, []string) { + if r.PathPattern != "" && r.Hostname != "" { + return "combined", []string{"path-pattern", "host-header"}, []string{r.PathPattern, r.Hostname + "." + domain} + } + if r.PathPattern != "" { + return "pathPattern", []string{"path-pattern"}, []string{r.PathPattern} + } + if r.Hostname != "" { + return "hostname", []string{"host-header"}, []string{r.Hostname + "." + domain} + } + return "", []string{}, []string{} +} + // Deploy rules for a specific targetGroup func (c *Controller) createRulesForTarget(serviceName string, d service.Deploy, targetGroupArn *string, alb *ecs.ALB) ([]string, error) { var listeners []string @@ -450,32 +556,16 @@ func (c *Controller) createRulesForTarget(serviceName string, d service.Deploy, if len(d.RuleConditions) > 0 { // create rules based on conditions var newRules int - for _, r := range d.RuleConditions { - if r.PathPattern != "" && r.Hostname != "" { - rules := []string{r.PathPattern, r.Hostname} - l, err := alb.CreateRuleForListeners("combined", r.Listeners, *targetGroupArn, rules, (priority + 10 + int64(newRules))) - if err != nil { - return nil, err - } - newRules += len(r.Listeners) - listeners = append(listeners, l...) - } else if r.PathPattern != "" { - rules := []string{r.PathPattern} - l, err := alb.CreateRuleForListeners("pathPattern", r.Listeners, *targetGroupArn, rules, (priority + 10 + int64(newRules))) - if err != nil { - return nil, err - } - newRules += len(r.Listeners) - listeners = append(listeners, l...) - } else if r.Hostname != "" { - rules := []string{r.Hostname} - l, err := alb.CreateRuleForListeners("hostname", r.Listeners, *targetGroupArn, rules, (priority + 10 + int64(newRules))) - if err != nil { - return nil, err - } - newRules += len(r.Listeners) - listeners = append(listeners, l...) + ruleConditionsSorted := d.RuleConditions + sort.Sort(ruleConditionSort(ruleConditionsSorted)) + for _, r := range ruleConditionsSorted { + ruleType, _, conditionValue := c.getALBConditionFieldAndValue(*r, alb.GetDomain()) + l, err := alb.CreateRuleForListeners(ruleType, r.Listeners, *targetGroupArn, conditionValue, (priority + 10 + int64(newRules)), r.CognitoAuth) + if err != nil { + return nil, err } + newRules += len(r.Listeners) + listeners = append(listeners, l...) } } else { // create default rules ( /servicename path on all listeners ) diff --git a/api/sort.go b/api/sort.go new file mode 100644 index 0000000..49eb961 --- /dev/null +++ b/api/sort.go @@ -0,0 +1,17 @@ +package api + +import ( + "github.com/in4it/ecs-deploy/service" +) + +type ruleConditionSort []*service.DeployRuleConditions + +func (s ruleConditionSort) Len() int { + return len(s) +} +func (s ruleConditionSort) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} +func (s ruleConditionSort) Less(i, j int) bool { + return len(s[i].PathPattern) > len(s[j].PathPattern) +} diff --git a/api/sort_test.go b/api/sort_test.go new file mode 100644 index 0000000..5eb2891 --- /dev/null +++ b/api/sort_test.go @@ -0,0 +1,43 @@ +package api + +import ( + "sort" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/in4it/ecs-deploy/service" +) + +func TestRuleConditionSort(t *testing.T) { + conditions := []*service.DeployRuleConditions{ + { + Hostname: "test", + }, + { + Hostname: "test", + PathPattern: "/api", + }, + { + Hostname: "test", + PathPattern: "/api/v1", + }, + } + conditionsSorted := []*service.DeployRuleConditions{ + { + Hostname: "test", + PathPattern: "/api/v1", + }, + { + Hostname: "test", + PathPattern: "/api", + }, + { + Hostname: "test", + }, + } + sort.Sort(ruleConditionSort(conditions)) + + if !cmp.Equal(conditions, conditionsSorted) { + t.Errorf("Conditions is not correctly sorted") + } +} diff --git a/provider/ecs/alb.go b/provider/ecs/alb.go index a768a2c..fd5a5b3 100644 --- a/provider/ecs/alb.go +++ b/provider/ecs/alb.go @@ -340,7 +340,7 @@ func (a *ALB) GetHighestRule() (int64, error) { return 0, errors.New("Could not describe alb listener rules") } - albLogger.Debugf("Looping rules: %+v", result.Rules) + albLogger.Tracef("Looping rules: %+v", result.Rules) for _, rule := range result.Rules { if i, _ := strconv.ParseInt(*rule.Priority, 10, 64); i > highest { albLogger.Debugf("Found rule with priority: %d", i) @@ -364,7 +364,7 @@ func (a *ALB) GetHighestRule() (int64, error) { func (a *ALB) CreateRuleForAllListeners(ruleType string, targetGroupArn string, rules []string, priority int64) ([]string, error) { var listeners []string for _, l := range a.Listeners { - err := a.CreateRule(ruleType, *l.ListenerArn, targetGroupArn, rules, priority) + err := a.CreateRule(ruleType, *l.ListenerArn, targetGroupArn, rules, priority, service.DeployRuleConditionsCognitoAuth{}) if err != nil { return nil, err } @@ -373,61 +373,149 @@ func (a *ALB) CreateRuleForAllListeners(ruleType string, targetGroupArn string, return listeners, nil } -func (a *ALB) CreateRuleForListeners(ruleType string, listeners []string, targetGroupArn string, rules []string, priority int64) ([]string, error) { - var retListeners []string +func (a *ALB) CreateRuleForListeners(ruleType string, listeners []string, targetGroupArn string, rules []string, priority int64, cognitoAuth service.DeployRuleConditionsCognitoAuth) ([]string, error) { + retListeners := a.getListenersArnForProtocol(listeners) + for k, listener := range retListeners { + var err error + // if cognito is set, a redirect is needed instead (cognito doesn't work with http) + if strings.ToLower(listeners[k]) == "http" && cognitoAuth.ClientName != "" { + err = a.CreateHTTPSRedirectRule(ruleType, listener, targetGroupArn, rules, priority) + } else { + err = a.CreateRule(ruleType, listener, targetGroupArn, rules, priority, cognitoAuth) + } + if err != nil { + return nil, err + } + } + return retListeners, nil +} + +func (a *ALB) getListenersArnForProtocol(listeners []string) []string { + var listenersArn []string for _, l := range a.Listeners { for _, l2 := range listeners { - if l.Protocol != nil && strings.ToLower(*l.Protocol) == strings.ToLower(l2) { - err := a.CreateRule(ruleType, *l.ListenerArn, targetGroupArn, rules, priority) - if err != nil { - return nil, err - } - retListeners = append(retListeners, *l.ListenerArn) + if l.Protocol != nil && strings.ToLower(aws.StringValue(l.Protocol)) == strings.ToLower(l2) { + listenersArn = append(listenersArn, aws.StringValue(l.ListenerArn)) } } } - return retListeners, nil + albLogger.Debugf("getListenersArnForProtocol: resolved %s to %s", strings.Join(listeners, ","), strings.Join(listenersArn, ",")) + + return listenersArn +} + +/* + * Gets listeners ARN based on http / https string + */ +func (a *ALB) GetListenerArnForProtocol(listener string) string { + listeners := a.getListenersArnForProtocol([]string{listener}) + if len(listeners) == 1 { + return listeners[0] + } + return "" } -func (a *ALB) CreateRule(ruleType string, listenerArn string, targetGroupArn string, rules []string, priority int64) error { +/* + * modify an existing rule to a https redirect + */ +func (a *ALB) UpdateRuleToHTTPSRedirect(targetGroupArn, ruleArn string, ruleType string, rules []string) error { svc := elbv2.New(session.New()) - input := &elbv2.CreateRuleInput{ + input := &elbv2.ModifyRuleInput{ + Actions: []*elbv2.Action{ + { + RedirectConfig: &elbv2.RedirectActionConfig{ + Protocol: aws.String("HTTPS"), + StatusCode: aws.String("HTTP_301"), + Port: aws.String("443"), + }, + Type: aws.String("redirect"), + }, + }, + RuleArn: aws.String(ruleArn), + } + conditions, err := a.getRuleConditions(ruleType, rules) + if err != nil { + return err + } + input.SetConditions(conditions) + + _, err = svc.ModifyRule(input) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + albLogger.Errorf(aerr.Error()) + } else { + albLogger.Errorf(err.Error()) + } + return errors.New("Could not modify alb rule") + } + return nil +} + +func (a *ALB) UpdateRule(targetGroupArn, ruleArn string, ruleType string, rules []string, cognitoAuth service.DeployRuleConditionsCognitoAuth) error { + svc := elbv2.New(session.New()) + input := &elbv2.ModifyRuleInput{ Actions: []*elbv2.Action{ { TargetGroupArn: aws.String(targetGroupArn), Type: aws.String("forward"), }, }, - ListenerArn: aws.String(listenerArn), - Priority: aws.Int64(priority), + RuleArn: aws.String(ruleArn), + } + conditions, err := a.getRuleConditions(ruleType, rules) + if err != nil { + return err + } + input.SetConditions(conditions) + + // cognito + if cognitoAuth.UserPoolName != "" && cognitoAuth.ClientName != "" { + cognitoAction, err := a.getCognitoAction(targetGroupArn, cognitoAuth) + if err != nil { + return err + } + input.SetActions(cognitoAction) } + _, err = svc.ModifyRule(input) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + albLogger.Errorf(aerr.Error()) + } else { + albLogger.Errorf(err.Error()) + } + return errors.New("Could not modify alb rule") + } + return nil +} + +func (a *ALB) getRuleConditions(ruleType string, rules []string) ([]*elbv2.RuleCondition, error) { if ruleType == "pathPattern" { if len(rules) != 1 { - return errors.New("Wrong number of rules (expected 1, got " + strconv.Itoa(len(rules)) + ")") + return nil, errors.New("Wrong number of rules (expected 1, got " + strconv.Itoa(len(rules)) + ")") } - input.SetConditions([]*elbv2.RuleCondition{ + return []*elbv2.RuleCondition{ { Field: aws.String("path-pattern"), Values: []*string{aws.String(rules[0])}, }, - }) + }, nil } else if ruleType == "hostname" { if len(rules) != 1 { - return errors.New("Wrong number of rules (expected 1, got " + strconv.Itoa(len(rules)) + ")") + return nil, errors.New("Wrong number of rules (expected 1, got " + strconv.Itoa(len(rules)) + ")") } - hostname := rules[0] + "." + util.GetEnv("LOADBALANCER_DOMAIN", a.Domain) - input.SetConditions([]*elbv2.RuleCondition{ + hostname := rules[0] + return []*elbv2.RuleCondition{ { Field: aws.String("host-header"), Values: []*string{aws.String(hostname)}, }, - }) + }, nil } else if ruleType == "combined" { if len(rules) != 2 { - return errors.New("Wrong number of rules (expected 2, got " + strconv.Itoa(len(rules)) + ")") + return nil, errors.New("Wrong number of rules (expected 2, got " + strconv.Itoa(len(rules)) + ")") } - hostname := rules[1] + "." + util.GetEnv("LOADBALANCER_DOMAIN", a.Domain) - input.SetConditions([]*elbv2.RuleCondition{ + hostname := rules[1] + return []*elbv2.RuleCondition{ { Field: aws.String("path-pattern"), Values: []*string{aws.String(rules[0])}, @@ -436,43 +524,73 @@ func (a *ALB) CreateRule(ruleType string, listenerArn string, targetGroupArn str Field: aws.String("host-header"), Values: []*string{aws.String(hostname)}, }, - }) + }, nil + } else { - return errors.New("ruleType not recognized: " + ruleType) + return nil, errors.New("ruleType not recognized: " + ruleType) } +} - _, err := svc.CreateRule(input) +func (a *ALB) CreateHTTPSRedirectRule(ruleType string, listenerArn string, targetGroupArn string, rules []string, priority int64) error { + svc := elbv2.New(session.New()) + input := &elbv2.CreateRuleInput{ + Actions: []*elbv2.Action{ + { + RedirectConfig: &elbv2.RedirectActionConfig{ + Protocol: aws.String("HTTPS"), + StatusCode: aws.String("HTTP_301"), + Port: aws.String("443"), + }, + Type: aws.String("redirect"), + }, + }, + ListenerArn: aws.String(listenerArn), + Priority: aws.Int64(priority), + } + conditions, err := a.getRuleConditions(ruleType, rules) if err != nil { - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case elbv2.ErrCodePriorityInUseException: - albLogger.Errorf(elbv2.ErrCodePriorityInUseException+": %v", aerr.Error()) - case elbv2.ErrCodeTooManyTargetGroupsException: - albLogger.Errorf(elbv2.ErrCodeTooManyTargetGroupsException+": %v", aerr.Error()) - case elbv2.ErrCodeTooManyRulesException: - albLogger.Errorf(elbv2.ErrCodeTooManyRulesException+": %v", aerr.Error()) - case elbv2.ErrCodeTargetGroupAssociationLimitException: - albLogger.Errorf(elbv2.ErrCodeTargetGroupAssociationLimitException+": %v", aerr.Error()) - case elbv2.ErrCodeIncompatibleProtocolsException: - albLogger.Errorf(elbv2.ErrCodeIncompatibleProtocolsException+": %v", aerr.Error()) - case elbv2.ErrCodeListenerNotFoundException: - albLogger.Errorf(elbv2.ErrCodeListenerNotFoundException+": %v", aerr.Error()) - case elbv2.ErrCodeTargetGroupNotFoundException: - albLogger.Errorf(elbv2.ErrCodeTargetGroupNotFoundException+": %v", aerr.Error()) - case elbv2.ErrCodeInvalidConfigurationRequestException: - albLogger.Errorf(elbv2.ErrCodeInvalidConfigurationRequestException+": %v", aerr.Error()) - case elbv2.ErrCodeTooManyRegistrationsForTargetIdException: - albLogger.Errorf(elbv2.ErrCodeTooManyRegistrationsForTargetIdException+": %v", aerr.Error()) - case elbv2.ErrCodeTooManyTargetsException: - albLogger.Errorf(elbv2.ErrCodeTooManyTargetsException+": %v", aerr.Error()) - default: - albLogger.Errorf(aerr.Error()) - } - } else { - // Print the error, cast err to awserr.Error to get the Code and - // Message from an error. - albLogger.Errorf(err.Error()) + return err + } + input.SetConditions(conditions) + + _, err = svc.CreateRule(input) + if err != nil { + albLogger.Errorf(err.Error()) + return errors.New("Could not create alb rule") + } + return nil +} + +func (a *ALB) CreateRule(ruleType string, listenerArn string, targetGroupArn string, rules []string, priority int64, cognitoAuth service.DeployRuleConditionsCognitoAuth) error { + svc := elbv2.New(session.New()) + input := &elbv2.CreateRuleInput{ + Actions: []*elbv2.Action{ + { + TargetGroupArn: aws.String(targetGroupArn), + Type: aws.String("forward"), + }, + }, + ListenerArn: aws.String(listenerArn), + Priority: aws.Int64(priority), + } + conditions, err := a.getRuleConditions(ruleType, rules) + if err != nil { + return err + } + input.SetConditions(conditions) + + // cognito + if cognitoAuth.UserPoolName != "" && cognitoAuth.ClientName != "" { + cognitoAction, err := a.getCognitoAction(targetGroupArn, cognitoAuth) + if err != nil { + return err } + input.SetActions(cognitoAction) + } + + _, err = svc.CreateRule(input) + if err != nil { + albLogger.Errorf(err.Error()) return errors.New("Could not create alb rule") } return nil @@ -524,6 +642,47 @@ func (a *ALB) GetRulesByTargetGroupArn(targetGroupArn string) []string { } return result } +func (a *ALB) GetRuleByTargetGroupArnWithAuth(targetGroupArn string) []string { + var result []string + for _, rules := range a.Rules { + for _, rule := range rules { + foundAuthType := false + for _, ruleAction := range rule.Actions { + if aws.StringValue(ruleAction.Type) == "authenticate-cognito" { + foundAuthType = true + } + } + if foundAuthType { + for _, ruleAction := range rule.Actions { + if aws.StringValue(ruleAction.TargetGroupArn) == targetGroupArn { + result = append(result, aws.StringValue(rule.RuleArn)) + } + } + } + } + } + return result +} +func (a *ALB) GetConditionsForRule(ruleArn string) ([]string, []string) { + conditionFields := []string{} + conditionValues := []string{} + for _, rules := range a.Rules { + for _, rule := range rules { + if aws.StringValue(rule.RuleArn) == ruleArn { + for _, condition := range rule.Conditions { + if aws.StringValue(condition.Field) == "path-pattern" || aws.StringValue(condition.Field) == "host-header" { + conditionFields = append(conditionFields, aws.StringValue(condition.Field)) + if len(condition.Values) >= 1 { + conditionValues = append(conditionValues, aws.StringValue(condition.Values[0])) + } + } + } + } + } + } + return conditionFields, conditionValues +} + func (a *ALB) GetTargetGroupArn(serviceName string) (*string, error) { svc := elbv2.New(session.New()) input := &elbv2.DescribeTargetGroupsInput{ @@ -559,7 +718,13 @@ func (a *ALB) GetTargetGroupArn(serviceName string) (*string, error) { func (a *ALB) GetDomain() string { return util.GetEnv("LOADBALANCER_DOMAIN", a.Domain) } + +/* + * FindRule tries to find a matching rule in the Rules map + */ func (a *ALB) FindRule(listener string, targetGroupArn string, conditionField []string, conditionValue []string) (*string, *string, error) { + albLogger.Debugf("Find Rule: listener %s, targetGroupArn %s, conditionField %s, conditionValue %s", listener, targetGroupArn, strings.Join(conditionField, ","), strings.Join(conditionValue, ",")) + if len(conditionField) != len(conditionValue) { return nil, nil, errors.New("conditionField length not equal to conditionValue length") } @@ -567,25 +732,19 @@ func (a *ALB) FindRule(listener string, targetGroupArn string, conditionField [] if rules, ok := a.Rules[listener]; ok { for _, r := range rules { for _, a := range r.Actions { - if *a.Type == "forward" && *a.TargetGroupArn == targetGroupArn { - // target group found, loop over conditions - priorityFound := false - skip := false + if (aws.StringValue(a.Type) == "forward" && aws.StringValue(a.TargetGroupArn) == targetGroupArn) || aws.StringValue(a.Type) == "redirect" { + // possible action match found, checking conditions + matchingConditions := []bool{} for _, c := range r.Conditions { match := false - for i, _ := range conditionField { - if *c.Field == conditionField[i] && len(c.Values) > 0 && *c.Values[0] == conditionValue[i] { + for i := range conditionField { + if aws.StringValue(c.Field) == conditionField[i] && len(c.Values) > 0 && aws.StringValue(c.Values[0]) == conditionValue[i] { match = true } } - if !skip && match { // if any condition was false, skip this rule - priorityFound = true - } else { - priorityFound = false - skip = true - } + matchingConditions = append(matchingConditions, match) } - if priorityFound { + if len(matchingConditions) == len(conditionField) && util.IsBoolArrayTrue(matchingConditions) { return r.RuleArn, r.Priority, nil } } @@ -693,3 +852,28 @@ func (a *ALB) DeleteRule(ruleArn string) error { } return nil } +func (a *ALB) getCognitoAction(targetGroupArn string, cognitoAuth service.DeployRuleConditionsCognitoAuth) ([]*elbv2.Action, error) { + // get cognito user pool info + cognito := CognitoIdp{} + userPoolArn, userPoolClientID, userPoolDomain, err := cognito.getUserPoolInfo(cognitoAuth.UserPoolName, cognitoAuth.ClientName) + if err != nil { + return nil, err + } + return []*elbv2.Action{ + { + AuthenticateCognitoConfig: &elbv2.AuthenticateCognitoActionConfig{ + OnUnauthenticatedRequest: aws.String("deny"), + UserPoolArn: aws.String(userPoolArn), + UserPoolClientId: aws.String(userPoolClientID), + UserPoolDomain: aws.String(userPoolDomain), + }, + Type: aws.String("authenticate-cognito"), + Order: aws.Int64(1), + }, + { + TargetGroupArn: aws.String(targetGroupArn), + Type: aws.String("forward"), + Order: aws.Int64(2), + }, + }, nil +} diff --git a/provider/ecs/alb_test.go b/provider/ecs/alb_test.go index c66f545..bdf7332 100644 --- a/provider/ecs/alb_test.go +++ b/provider/ecs/alb_test.go @@ -4,6 +4,8 @@ import ( "fmt" "testing" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/elbv2" "github.com/in4it/ecs-deploy/util" ) @@ -22,3 +24,104 @@ func TestGetHighestRule(t *testing.T) { } fmt.Printf("Highest rule in ALB (%v) is: %d ", a.loadBalancerName, highest) } + +func TestFindRule(t *testing.T) { + a := ALB{} + a.Rules = make(map[string][]*elbv2.Rule) + a.Rules["listener"] = []*elbv2.Rule{ + { + RuleArn: aws.String("1"), + Priority: aws.String("1"), + Actions: []*elbv2.Action{ + { + Type: aws.String("forward"), + TargetGroupArn: aws.String("targetGroup"), + }, + }, + Conditions: []*elbv2.RuleCondition{ + { + Field: aws.String("host-header"), + Values: []*string{aws.String("host.example.com")}, + }, + }, + }, + { + RuleArn: aws.String("2"), + Priority: aws.String("2"), + Actions: []*elbv2.Action{ + { + Type: aws.String("forward"), + TargetGroupArn: aws.String("targetGroup"), + }, + }, + Conditions: []*elbv2.RuleCondition{ + { + Field: aws.String("host-header"), + Values: []*string{aws.String("host-2.example.com")}, + }, + { + Field: aws.String("path-pattern"), + Values: []*string{aws.String("/api")}, + }, + }, + }, + { + RuleArn: aws.String("3"), + Priority: aws.String("3"), + Actions: []*elbv2.Action{ + { + Type: aws.String("forward"), + TargetGroupArn: aws.String("targetGroup"), + }, + }, + Conditions: []*elbv2.RuleCondition{ + { + Field: aws.String("host-header"), + Values: []*string{aws.String("host.example.com")}, + }, + { + Field: aws.String("path-pattern"), + Values: []*string{aws.String("/api/v1")}, + }, + }, + }, + { + RuleArn: aws.String("4"), + Priority: aws.String("4"), + Actions: []*elbv2.Action{ + { + Type: aws.String("forward"), + TargetGroupArn: aws.String("targetGroup"), + }, + }, + Conditions: []*elbv2.RuleCondition{ + { + Field: aws.String("host-header"), + Values: []*string{aws.String("host.example.com")}, + }, + { + Field: aws.String("path-pattern"), + Values: []*string{aws.String("/api")}, + }, + }, + }, + } + conditionField := []string{"host-header", "path-pattern"} + conditionValue := []string{"host.example.com", "/api"} + ruleArn, priority, err := a.FindRule("listener", "targetGroup", conditionField, conditionValue) + if err != nil { + t.Errorf("Error: %v", err) + } + if *priority != "4" || *ruleArn != "4" { + t.Errorf("Error: found wrong rule") + } + // re-order + a.Rules["listener"][0], a.Rules["listener"][3] = a.Rules["listener"][3], a.Rules["listener"][0] + ruleArn, priority, err = a.FindRule("listener", "targetGroup", conditionField, conditionValue) + if err != nil { + t.Errorf("Error: %v", err) + } + if *priority != "4" || *ruleArn != "4" { + t.Errorf("Error: found wrong rule") + } +} diff --git a/provider/ecs/cognito-idp.go b/provider/ecs/cognito-idp.go new file mode 100644 index 0000000..bb68ebb --- /dev/null +++ b/provider/ecs/cognito-idp.go @@ -0,0 +1,113 @@ +package ecs + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cognitoidentityprovider" + "github.com/juju/loggo" +) + +// logging +var cognitoLogger = loggo.GetLogger("cognito") + +// Cognito struct +type CognitoIdp struct { +} + +func (c *CognitoIdp) getUserPoolInfo(userPoolName, userPoolClientName string) (string, string, string, error) { + userPoolID, err := c.getUserPoolArn(userPoolName) + if err != nil { + return "", "", "", err + } + + userPool, err := c.describeUserPool(userPoolID) + if err != nil { + return "", "", "", err + } + + userPoolClientID, err := c.getUserPoolClientID(userPoolID, userPoolClientName) + if err != nil { + return "", "", "", err + } + + return aws.StringValue(userPool.Arn), userPoolClientID, aws.StringValue(userPool.Domain), nil +} + +func (c *CognitoIdp) describeUserPool(userPoolID string) (*cognitoidentityprovider.UserPoolType, error) { + svc := cognitoidentityprovider.New(session.New()) + input := &cognitoidentityprovider.DescribeUserPoolInput{ + UserPoolId: aws.String(userPoolID), + } + + res, err := svc.DescribeUserPool(input) + if err != nil { + return nil, err + } + + return res.UserPool, nil +} + +func (c *CognitoIdp) getUserPoolArn(userPoolName string) (string, error) { + svc := cognitoidentityprovider.New(session.New()) + input := &cognitoidentityprovider.ListUserPoolsInput{ + MaxResults: aws.Int64(60), + } + + userPoolID := "" + + pageNum := 0 + err := svc.ListUserPoolsPages(input, + func(page *cognitoidentityprovider.ListUserPoolsOutput, lastPage bool) bool { + pageNum++ + for _, userPool := range page.UserPools { + if aws.StringValue(userPool.Name) == userPoolName { + userPoolID = aws.StringValue(userPool.Id) + } + } + return pageNum <= 100 + }) + + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + cognitoLogger.Errorf(aerr.Error()) + } else { + cognitoLogger.Errorf(err.Error()) + } + return userPoolID, err + } + if userPoolID == "" { + return userPoolID, fmt.Errorf("Could not find userpool with name %s", userPoolName) + } + return userPoolID, nil +} +func (c *CognitoIdp) getUserPoolClientID(userPoolID, userPoolClientName string) (string, error) { + svc := cognitoidentityprovider.New(session.New()) + input := &cognitoidentityprovider.ListUserPoolClientsInput{ + UserPoolId: aws.String(userPoolID), + } + + userPoolClientNameID := "" + + pageNum := 0 + err := svc.ListUserPoolClientsPages(input, + func(page *cognitoidentityprovider.ListUserPoolClientsOutput, lastPage bool) bool { + pageNum++ + for _, userPoolClient := range page.UserPoolClients { + if aws.StringValue(userPoolClient.ClientName) == userPoolClientName { + userPoolClientNameID = aws.StringValue(userPoolClient.ClientId) + } + } + return pageNum <= 100 + }) + + if err != nil { + return userPoolClientNameID, err + } + if userPoolClientNameID == "" { + return userPoolClientNameID, fmt.Errorf("Could not find userpool client with name %s", userPoolClientName) + } + return userPoolClientNameID, nil +} diff --git a/service/deploy.go b/service/deploy.go index 2f1dc18..e098515 100644 --- a/service/deploy.go +++ b/service/deploy.go @@ -92,9 +92,14 @@ type DeployHealthCheck struct { GracePeriodSeconds int64 `json:"gracePeriodSeconds" yaml:"gracePeriodSeconds"` } type DeployRuleConditions struct { - Listeners []string `json:"listeners" yaml:"listeners"` - PathPattern string `json:"pathPattern" yaml:"pathPattern"` - Hostname string `json:"hostname" yaml:"hostname"` + Listeners []string `json:"listeners" yaml:"listeners"` + PathPattern string `json:"pathPattern" yaml:"pathPattern"` + Hostname string `json:"hostname" yaml:"hostname"` + CognitoAuth DeployRuleConditionsCognitoAuth `json:"cognitoAuth" yaml:"cognitoAuth"` +} +type DeployRuleConditionsCognitoAuth struct { + UserPoolName string `json:"userPoolName" yaml:"userPoolName"` + ClientName string `json:"clientName" yaml:"clientName"` } type DeployStickiness struct { Enabled bool `json:"enabled" yaml:"enabled"` diff --git a/util/common.go b/util/common.go index 1a4a60f..f45fb08 100644 --- a/util/common.go +++ b/util/common.go @@ -77,3 +77,59 @@ func TruncateString(str string, n int) string { return str } } + +/* + * RemoveCommonElements removes the common elements and returns the first array + */ +func RemoveCommonElements(a, b []string) []string { + var c = []string{} + var d = []string{} + + for _, item1 := range a { + for _, item2 := range b { + if item1 == item2 { + c = append(c, item2) + } + } + } + + for _, item1 := range a { + found := false + for _, item2 := range c { + if item1 == item2 { + found = true + } + } + if !found { + d = append(d, item1) + } + } + return d +} + +/* + * IsBoolArrayTrue checks whether the array contains only true elements + */ +func IsBoolArrayTrue(array []bool) bool { + if len(array) == 0 { + return false + } + for _, v := range array { + if !v { + return false + } + } + return true +} + +/* + * InArray returns true if the value exists in the array + */ +func InArray(a []string, v string) (ret bool, i int) { + for i = range a { + if ret = a[i] == v; ret { + return ret, i + } + } + return false, -1 +} diff --git a/util/common_test.go b/util/common_test.go new file mode 100644 index 0000000..4424c1b --- /dev/null +++ b/util/common_test.go @@ -0,0 +1,28 @@ +package util + +import "testing" + +func TestFindRule(t *testing.T) { + a := []string{"1", "2", "3"} + b := []string{"1", "2", "4"} + c := RemoveCommonElements(a, b) + d := RemoveCommonElements(b, a) + if c[0] != "3" { + t.Errorf("Unexpected result: %s (wanted %s)", c[0], "3") + } + if d[0] != "4" { + t.Errorf("Unexpected result: %s (wanted %s)", c[0], "3") + } +} + +func TestInArray(t *testing.T) { + a := []string{"1", "2", "3"} + c, i := InArray(a, "2") + d, y := InArray(a, "4") + if c == false || i != 1 { + t.Errorf("Unexpected result: %v (wanted %v)", c, true) + } + if d == true || y != -1 { + t.Errorf("Unexpected result: %v (wanted %v)", d, false) + } +}