Skip to content

Commit

Permalink
Merge pull request Cloud-Foundations#39 from rgooch/refactor-dnslb-dr…
Browse files Browse the repository at this point in the history
…ivers

Refactor DNSLB drivers.
  • Loading branch information
rgooch authored Apr 11, 2021
2 parents 3301410 + 37739af commit 2db936f
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 98 deletions.
4 changes: 2 additions & 2 deletions pkg/loadbalancing/dnslb/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ type Params struct {
// RecordReadWriter implements a DNS record reader and writer. It is used to
// plugin the underlying DNS provider.
type RecordReadWriter interface {
ReadRecord(fqdn string) ([]string, error)
WriteRecord(fqdn string, ips []string, ttl time.Duration) error
ReadRecords(fqdn, recType string) ([]string, time.Duration, error)
WriteRecords(fqdn, recType string, recs []string, ttl time.Duration) error
}

// RegionFilter implements the Filter method, which is used to restrict DNS
Expand Down
4 changes: 3 additions & 1 deletion pkg/loadbalancing/dnslb/config/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
)

type Config struct {
AllRegions bool `yaml:"all_regions"`
AllRegions bool `yaml:"all_regions"`
AwsAssumeRoleArn string `yaml:"aws_assume_role_arn"`
AwsProfile string `yaml:"aws_profile"`
dnslb.Config `yaml:",inline"`
Preserve bool `yaml:"preserve"`
Route53HostedZoneId string `yaml:"route53_hosted_zone_id"`
Expand Down
102 changes: 102 additions & 0 deletions pkg/loadbalancing/dnslb/config/aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package config

import (
"errors"
"fmt"
"time"

"github.com/Cloud-Foundations/golib/pkg/awsutil/metadata"
"github.com/Cloud-Foundations/golib/pkg/loadbalancing/dnslb"
"github.com/Cloud-Foundations/golib/pkg/loadbalancing/dnslb/ec2"
"github.com/Cloud-Foundations/golib/pkg/loadbalancing/dnslb/route53"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
)

func awsConfigure(config *Config, params *dnslb.Params, region string) error {
if config.CheckInterval < 1 {
config.CheckInterval = time.Minute
}
awsSession, err := awsCreateSession(config)
if err != nil {
return err
}
if err := awsEC2Configure(awsSession, config, params, region); err != nil {
return err
}
if err := awsRoute53Configure(awsSession, config, params); err != nil {
return err
}
return nil
}

func awsCreateSession(config *Config) (*session.Session, error) {
var awsSession *session.Session
var err error
if config.AwsProfile == "" {
awsSession, err = session.NewSession(&aws.Config{})
} else {
awsSession, err = session.NewSessionWithOptions(session.Options{
Profile: config.AwsProfile,
})
}
if err != nil {
return nil, fmt.Errorf("error creating session: %s", err)
}
if awsSession == nil {
return nil, errors.New("awsSession == nil")
}
if config.AwsAssumeRoleArn == "" {
return awsSession, nil
}
creds := stscreds.NewCredentials(awsSession, config.AwsAssumeRoleArn)
assumedSession, err := session.NewSession(&aws.Config{Credentials: creds})
if err != nil {
return nil, fmt.Errorf("error creating assumed role session: %s", err)
}
if assumedSession == nil {
return nil, errors.New("assumedSession == nil")
}
return assumedSession, nil
}

func awsEC2Configure(awsSession *session.Session, config *Config,
params *dnslb.Params, region string) error {
if config.AllRegions {
if !config.Preserve {
return errors.New("cannot destroy instances in other regions")
}
return nil
}
if region == "" {
metadataClient, err := metadata.GetMetadataClient()
if err != nil {
return err
}
region, err = metadataClient.Region()
if err != nil {
return err
}
}
instanceHandler, err := ec2.New(awsSession, region, params.Logger)
if err != nil {
return err
}
params.RegionFilter = instanceHandler
if !config.Preserve {
params.Destroyer = instanceHandler
}
return nil
}

func awsRoute53Configure(awsSession *session.Session, config *Config,
params *dnslb.Params) error {
var err error
params.RecordReadWriter, err = route53.New(awsSession,
config.Route53HostedZoneId, params.Logger)
if err != nil {
return err
}
return nil
}
59 changes: 17 additions & 42 deletions pkg/loadbalancing/dnslb/config/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,48 @@ package config

import (
"errors"
"time"

"github.com/Cloud-Foundations/golib/pkg/awsutil/metadata"
"github.com/Cloud-Foundations/golib/pkg/loadbalancing/dnslb"
"github.com/Cloud-Foundations/golib/pkg/loadbalancing/dnslb/ec2"
"github.com/Cloud-Foundations/golib/pkg/loadbalancing/dnslb/route53"
"github.com/Cloud-Foundations/golib/pkg/log"
)

type dnsConfigureFunc func(config *Config, params *dnslb.Params) error

func awsRoute53Configure(config *Config, params *dnslb.Params) error {
if config.CheckInterval < 1 {
config.CheckInterval = time.Minute
}
var err error
params.RecordReadWriter, err = route53.New(config.Route53HostedZoneId,
params.Logger)
if err != nil {
return err
}
if config.AllRegions {
if !config.Preserve {
return errors.New("cannot destroy instances in other regions")
}
return nil
}
metadataClient, err := metadata.GetMetadataClient()
if err != nil {
return err
}
instanceHandler, err := ec2.New(metadataClient, params.Logger)
if err != nil {
return err
}
params.RegionFilter = instanceHandler
if !config.Preserve {
params.Destroyer = instanceHandler
}
return nil
}
type dnsConfigureFunc func(config *Config, params *dnslb.Params,
region string) error

func getDnsConfigureFuncs(config Config) ([]dnsConfigureFunc, error) {
funcs := make([]dnsConfigureFunc, 0)
if config.Route53HostedZoneId != "" {
funcs = append(funcs, awsRoute53Configure)
funcs = append(funcs, awsConfigure)
}
if len(funcs) > 1 {
return nil, errors.New("multiple DNS providers specified")
}
return funcs, nil
}

func newLoadBalancer(config Config,
logger log.DebugLogger) (*dnslb.LoadBalancer, error) {
funcs, err := getDnsConfigureFuncs(config)
func makeDnslbParams(config *Config, region string, logger log.DebugLogger) (
*dnslb.Params, error) {
funcs, err := getDnsConfigureFuncs(*config)
if err != nil {
return nil, err
}
if len(funcs) < 1 {
return nil, errors.New("no DNS zone provider specified")
}
params := dnslb.Params{Logger: logger}
if err := funcs[0](&config, &params); err != nil {
if err := funcs[0](config, &params, region); err != nil {
return nil, err
}
return &params, nil
}

func newLoadBalancer(config Config,
logger log.DebugLogger) (*dnslb.LoadBalancer, error) {
params, err := makeDnslbParams(&config, "", logger)
if err != nil {
return nil, err
}
return dnslb.New(config.Config, params)
return dnslb.New(config.Config, *params)
}

func (c Config) check() (bool, error) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/loadbalancing/dnslb/ec2/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"sync"

"github.com/Cloud-Foundations/golib/pkg/log"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
)

Expand All @@ -15,9 +15,9 @@ type InstanceHandler struct {
ipToInstance map[string]*string // Key: IP, value instance ID.
}

func New(metadataClient *ec2metadata.EC2Metadata,
func New(awsSession *session.Session, region string,
logger log.DebugLogger) (*InstanceHandler, error) {
return newInstanceHandler(metadataClient, logger)
return newInstanceHandler(awsSession, region, logger)
}

func (h *InstanceHandler) Destroy(ips map[string]struct{}) error {
Expand Down
18 changes: 2 additions & 16 deletions pkg/loadbalancing/dnslb/ec2/impl.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,19 @@
package ec2

import (
"errors"
"fmt"

"github.com/Cloud-Foundations/golib/pkg/log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
)

type ipMap map[string]struct{}

func newInstanceHandler(metadataClient *ec2metadata.EC2Metadata,
func newInstanceHandler(awsSession *session.Session, region string,
logger log.DebugLogger) (*InstanceHandler, error) {
region, err := metadataClient.Region()
if err != nil {
return nil, err
}
awsSession, err := session.NewSession(&aws.Config{
Region: aws.String(region),
})
if err != nil {
return nil, fmt.Errorf("error creating session: %s", err)
}
if awsSession == nil {
return nil, errors.New("awsSession == nil")
}
awsSession = awsSession.Copy(&aws.Config{Region: aws.String(region)})
return &InstanceHandler{
awsService: ec2.New(awsSession),
logger: logger,
Expand Down
6 changes: 3 additions & 3 deletions pkg/loadbalancing/dnslb/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (lb *LoadBalancer) checkLoop() {
}

func (lb *LoadBalancer) check() error {
checkList, err := lb.p.RecordReadWriter.ReadRecord(lb.config.FQDN)
checkList, _, err := lb.p.RecordReadWriter.ReadRecords(lb.config.FQDN, "A")
if err != nil {
return err
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func (lb *LoadBalancer) check() error {
if err := lb.destroy(removeMap); err != nil {
return err
}
oldList, err := lb.p.RecordReadWriter.ReadRecord(lb.config.FQDN)
oldList, _, err := lb.p.RecordReadWriter.ReadRecords(lb.config.FQDN, "A")
if err != nil {
return err
}
Expand Down Expand Up @@ -204,7 +204,7 @@ func (lb *LoadBalancer) check() error {
return nil
}
lb.p.Logger.Printf("updating DNS for: %s: %v\n", lb.config.FQDN, newList)
return lb.p.RecordReadWriter.WriteRecord(lb.config.FQDN, newList,
return lb.p.RecordReadWriter.WriteRecords(lb.config.FQDN, "A", newList,
lb.config.CheckInterval)
}

Expand Down
16 changes: 9 additions & 7 deletions pkg/loadbalancing/dnslb/route53/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/Cloud-Foundations/golib/pkg/log"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/route53"
)

Expand All @@ -19,16 +20,17 @@ type RecordReadWriter struct {

// New creates a *RecordReadWriter.
// The logger is used for logging messages.
func New(hostedZoneId string,
func New(awsSession *session.Session, hostedZoneId string,
logger log.DebugLogger) (*RecordReadWriter, error) {
return newRecordReadWriter(hostedZoneId, logger)
return newRecordReadWriter(awsSession, hostedZoneId, logger)
}

func (rrw *RecordReadWriter) ReadRecord(fqdn string) ([]string, error) {
return rrw.readRecord(fqdn)
func (rrw *RecordReadWriter) ReadRecords(fqdn, recType string) (
[]string, time.Duration, error) {
return rrw.readRecords(fqdn, recType)
}

func (rrw *RecordReadWriter) WriteRecord(fqdn string, ips []string,
ttl time.Duration) error {
return rrw.writeRecord(fqdn, ips, ttl)
func (rrw *RecordReadWriter) WriteRecords(fqdn, recType string,
records []string, ttl time.Duration) error {
return rrw.writeRecords(fqdn, recType, records, ttl)
}
Loading

0 comments on commit 2db936f

Please sign in to comment.