Skip to content

Commit

Permalink
Refactor storage to properly support dynamic configuration; to pick u…
Browse files Browse the repository at this point in the history
…p changes of env-vars or files, we need to recreate the storage-clients on every interval (#15)
  • Loading branch information
Argelbargel authored Sep 23, 2023
1 parent b13c544 commit 772e852
Show file tree
Hide file tree
Showing 13 changed files with 426 additions and 278 deletions.
7 changes: 1 addition & 6 deletions internal/agent/snapshot-agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,7 @@ func (a *SnapshotAgent) reconfigure(ctx context.Context, config SnapshotAgentCon
return err
}

manager, err := storage.CreateManager(ctx, config.Snapshots.Storages)
if err != nil {
return err
}

a.update(ctx, client, manager, config.Snapshots.StorageConfigDefaults)
a.update(ctx, client, storage.CreateManager(config.Snapshots.Storages), config.Snapshots.StorageConfigDefaults)
return nil
}

Expand Down
87 changes: 49 additions & 38 deletions internal/agent/snapshot-agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ func TestTakeSnapshotUploadsSnapshot(t *testing.T) {
Frequency: time.Millisecond,
}

controller := &storageControllerStub{
nextSnapshot: time.Now().Add(time.Millisecond * 250),
}
factory := &storageControllerFactoryStub{nextSnapshot: time.Now().Add(time.Millisecond * 250)}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -39,10 +37,10 @@ func TestTakeSnapshotUploadsSnapshot(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Equal(t, clientVaultAPI.snapshotData, controller.uploadData)
assert.Equal(t, defaults, controller.defaults)
assert.WithinRange(t, controller.snapshotTimestamp, start, start.Add(50*time.Millisecond))
assert.GreaterOrEqual(t, time.Now(), controller.nextSnapshot)
assert.Equal(t, clientVaultAPI.snapshotData, factory.uploadData)
assert.Equal(t, defaults, factory.defaults)
assert.WithinRange(t, factory.snapshotTimestamp, start, start.Add(50*time.Millisecond))
assert.GreaterOrEqual(t, time.Now(), factory.nextSnapshot)
}

func TestTakeSnapshotLocksTakeSnapshot(t *testing.T) {
Expand Down Expand Up @@ -119,12 +117,12 @@ func TestTakeSnapshotFailsWhenTempFileCannotBeCreated(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Now().Add(defaults.Frequency * 4),
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -135,7 +133,7 @@ func TestTakeSnapshotFailsWhenTempFileCannotBeCreated(t *testing.T) {
<-timer.C

assert.False(t, clientVaultAPI.tookSnapshot)
assert.Less(t, time.Now(), controller.nextSnapshot.Add(-defaults.Frequency))
assert.Less(t, time.Now(), factory.nextSnapshot.Add(-defaults.Frequency))
}

func TestTakeSnapshotFailsWhenSnapshottingFails(t *testing.T) {
Expand All @@ -148,12 +146,12 @@ func TestTakeSnapshotFailsWhenSnapshottingFails(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Now().Add(defaults.Frequency * 4),
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -164,7 +162,7 @@ func TestTakeSnapshotFailsWhenSnapshottingFails(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Less(t, time.Now(), controller.nextSnapshot.Add(-defaults.Frequency))
assert.Less(t, time.Now(), factory.nextSnapshot.Add(-defaults.Frequency))
}

func TestTakeSnapshotIgnoresEmptySnapshot(t *testing.T) {
Expand All @@ -176,12 +174,12 @@ func TestTakeSnapshotIgnoresEmptySnapshot(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Now().Add(defaults.Frequency * 4),
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -192,7 +190,7 @@ func TestTakeSnapshotIgnoresEmptySnapshot(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Less(t, time.Now(), controller.nextSnapshot.Add(-defaults.Frequency))
assert.Less(t, time.Now(), factory.nextSnapshot.Add(-defaults.Frequency))
}

func TestIgnoresZeroTimeForScheduling(t *testing.T) {
Expand All @@ -205,12 +203,12 @@ func TestIgnoresZeroTimeForScheduling(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Time{},
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -222,7 +220,7 @@ func TestIgnoresZeroTimeForScheduling(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Equal(t, clientVaultAPI.snapshotData, controller.uploadData)
assert.Equal(t, clientVaultAPI.snapshotData, factory.uploadData)
assert.GreaterOrEqual(t, time.Now(), start.Add(defaults.Frequency))
}

Expand All @@ -233,24 +231,29 @@ func TestUpdateReschedulesSnapshots(t *testing.T) {
}

manager := &storage.Manager{}
manager.AddStorage(&storageControllerStub{nextSnapshot: time.Now().Add(time.Millisecond * 250)})
factory := &storageControllerFactoryStub{nextSnapshot: time.Now().Add(time.Millisecond * 250)}
manager.AddStorageFactory(factory)

newController := &storageControllerStub{nextSnapshot: time.Now().Add(time.Millisecond * 500)}
newFactory := &storageControllerFactoryStub{nextSnapshot: time.Now().Add(time.Millisecond * 500)}
newManager := &storage.Manager{}
newManager.AddStorage(newController)
newManager.AddStorageFactory(newFactory)

ctx := context.Background()
agent := newSnapshotAgent(t.TempDir())
agent.update(ctx, newClient(clientVaultAPI), manager, storage.StorageConfigDefaults{})
client := newClient(clientVaultAPI)
agent.update(ctx, client, manager, storage.StorageConfigDefaults{})
timer := agent.TakeSnapshot(ctx)

updated := make(chan bool, 1)
go func() {
agent.update(ctx, newClient(clientVaultAPI), newManager, storage.StorageConfigDefaults{})
agent.update(ctx, client, newManager, storage.StorageConfigDefaults{})
updated <- true
}()

<-updated
<-timer.C

assert.GreaterOrEqual(t, time.Now(), newController.nextSnapshot)
assert.GreaterOrEqual(t, time.Now(), newFactory.nextSnapshot)
assert.Equal(t, newManager, agent.manager)
}

Expand Down Expand Up @@ -304,36 +307,44 @@ func (stub clientVaultAPIAuthStub) Login(_ context.Context, _ any) (time.Duratio
return 0, nil
}

type storageControllerStub struct {
type storageControllerFactoryStub struct {
defaults storage.StorageConfigDefaults
uploadData string
uploadFails bool
snapshotTimestamp time.Time
nextSnapshot time.Time
}

func (stub *storageControllerStub) Destination() string {
func (stub *storageControllerFactoryStub) Destination() string {
return ""
}

func (stub *storageControllerStub) ScheduleSnapshot(_ context.Context, _ time.Time, _ storage.StorageConfigDefaults) time.Time {
return stub.nextSnapshot
func (stub *storageControllerFactoryStub) CreateController(context.Context) (storage.StorageController, error) {
return storageControllerStub{stub}, nil
}

type storageControllerStub struct {
factory *storageControllerFactoryStub
}

func (stub storageControllerStub) ScheduleSnapshot(_ context.Context, _ time.Time, _ storage.StorageConfigDefaults) (time.Time, error) {
return stub.factory.nextSnapshot, nil
}

func (stub *storageControllerStub) DeleteObsoleteSnapshots(_ context.Context, _ storage.StorageConfigDefaults) (int, error) {
func (stub storageControllerStub) DeleteObsoleteSnapshots(_ context.Context, _ storage.StorageConfigDefaults) (int, error) {
return 0, nil
}

func (stub *storageControllerStub) UploadSnapshot(_ context.Context, snapshot io.Reader, timestamp time.Time, defaults storage.StorageConfigDefaults) (bool, time.Time, error) {
stub.snapshotTimestamp = timestamp
stub.defaults = defaults
if stub.uploadFails {
return false, stub.nextSnapshot, errors.New("upload failed")
func (stub storageControllerStub) UploadSnapshot(_ context.Context, snapshot io.Reader, timestamp time.Time, defaults storage.StorageConfigDefaults) (bool, time.Time, error) {
stub.factory.snapshotTimestamp = timestamp
stub.factory.defaults = defaults
if stub.factory.uploadFails {
return false, stub.factory.nextSnapshot, errors.New("upload failed")
}
data, err := io.ReadAll(snapshot)
if err != nil {
return false, time.Now(), err
}
stub.uploadData = string(data)
return true, stub.nextSnapshot, nil
stub.factory.uploadData = string(data)
return true, stub.factory.nextSnapshot, nil
}
64 changes: 34 additions & 30 deletions internal/agent/storage/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,46 +37,50 @@ type awsStorageImpl struct {
sse bool
}

func createAWSStorageController(ctx context.Context, config AWSStorageConfig) (*storageControllerImpl[types.Object], error) {
func (conf AWSStorageConfig) Destination() string {
return fmt.Sprintf("aws s3 bucket %s at %s", conf.Bucket, conf.Endpoint)
}

func (conf AWSStorageConfig) CreateController(ctx context.Context) (StorageController, error) {
keyPrefix := ""
if config.KeyPrefix != "" {
keyPrefix = fmt.Sprintf("%s/", config.KeyPrefix)
if conf.KeyPrefix != "" {
keyPrefix = fmt.Sprintf("%s/", conf.KeyPrefix)
}

client, err := createS3Client(ctx, config)
client, err := conf.createClient(ctx)
if err != nil {
return nil, nil
return nil, err
}

return newStorageController[types.Object](
config.storageConfig,
fmt.Sprintf("aws s3 bucket %s at %s", config.Bucket, config.Endpoint),
conf.storageConfig,
awsStorageImpl{
client: client,
keyPrefix: keyPrefix,
bucket: config.Bucket,
sse: config.UseServerSideEncryption,
bucket: conf.Bucket,
sse: conf.UseServerSideEncryption,
},
), nil

}

func createS3Client(ctx context.Context, config AWSStorageConfig) (*s3.Client, error) {
accessKeyId, err := config.AccessKeyId.Resolve(false)
func (conf AWSStorageConfig) createClient(ctx context.Context) (*s3.Client, error) {
accessKeyId, err := conf.AccessKeyId.Resolve(false)
if err != nil {
return nil, err
}

accessKey, err := config.AccessKey.Resolve(accessKeyId != "")
accessKey, err := conf.AccessKey.Resolve(accessKeyId != "")
if err != nil {
return nil, err
}

sessionToken, err := config.SessionToken.Resolve(false)
sessionToken, err := conf.SessionToken.Resolve(false)
if err != nil {
return nil, err
}

region, err := config.Region.Resolve(false)
region, err := conf.Region.Resolve(false)
if err != nil {
return nil, err
}
Expand All @@ -90,14 +94,14 @@ func createS3Client(ctx context.Context, config AWSStorageConfig) (*s3.Client, e
clientConfig.Credentials = credentials.NewStaticCredentialsProvider(accessKeyId, accessKey, sessionToken)
}

endpoint, err := config.Endpoint.Resolve(false)
endpoint, err := conf.Endpoint.Resolve(false)
if err != nil {
return nil, err
}

client := s3.NewFromConfig(clientConfig, func(o *s3.Options) {
o.UsePathStyle = config.ForcePathStyle
if config.Endpoint != "" {
o.UsePathStyle = conf.ForcePathStyle
if conf.Endpoint != "" {
o.BaseEndpoint = aws.String(endpoint)
}
})
Expand All @@ -107,18 +111,18 @@ func createS3Client(ctx context.Context, config AWSStorageConfig) (*s3.Client, e

// nolint:unused
// implements interface storage
func (u awsStorageImpl) UploadSnapshot(ctx context.Context, name string, data io.Reader) error {
func (s awsStorageImpl) uploadSnapshot(ctx context.Context, name string, data io.Reader) error {
input := &s3.PutObjectInput{
Bucket: &u.bucket,
Key: aws.String(u.keyPrefix + name),
Bucket: &s.bucket,
Key: aws.String(s.keyPrefix + name),
Body: data,
}

if u.sse {
if s.sse {
input.ServerSideEncryption = types.ServerSideEncryptionAes256
}

uploader := manager.NewUploader(u.client)
uploader := manager.NewUploader(s.client)
if _, err := uploader.Upload(ctx, input); err != nil {
return err
}
Expand All @@ -128,13 +132,13 @@ func (u awsStorageImpl) UploadSnapshot(ctx context.Context, name string, data io

// nolint:unused
// implements interface storage
func (u awsStorageImpl) DeleteSnapshot(ctx context.Context, snapshot types.Object) error {
func (s awsStorageImpl) deleteSnapshot(ctx context.Context, snapshot types.Object) error {
input := &s3.DeleteObjectInput{
Bucket: &u.bucket,
Bucket: &s.bucket,
Key: snapshot.Key,
}

if _, err := u.client.DeleteObject(ctx, input); err != nil {
if _, err := s.client.DeleteObject(ctx, input); err != nil {
return err
}

Expand All @@ -143,12 +147,12 @@ func (u awsStorageImpl) DeleteSnapshot(ctx context.Context, snapshot types.Objec

// nolint:unused
// implements interface storage
func (u awsStorageImpl) ListSnapshots(ctx context.Context, prefix string, ext string) ([]types.Object, error) {
func (s awsStorageImpl) listSnapshots(ctx context.Context, prefix string, ext string) ([]types.Object, error) {
var result []types.Object

existingSnapshotList, err := u.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: &u.bucket,
Prefix: aws.String(u.keyPrefix),
existingSnapshotList, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: &s.bucket,
Prefix: aws.String(s.keyPrefix),
})

if err != nil {
Expand All @@ -166,6 +170,6 @@ func (u awsStorageImpl) ListSnapshots(ctx context.Context, prefix string, ext st

// nolint:unused
// implements interface storage
func (u awsStorageImpl) GetLastModifiedTime(snapshot types.Object) time.Time {
func (s awsStorageImpl) getLastModifiedTime(snapshot types.Object) time.Time {
return *snapshot.LastModified
}
Loading

0 comments on commit 772e852

Please sign in to comment.