Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor storage to properly support dynamic configuration; to pick u… #15

Merged
merged 1 commit into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions internal/agent/snapshot-agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,7 @@ func (a *SnapshotAgent) reconfigure(ctx context.Context, config SnapshotAgentCon
return err
}

manager, err := storage.CreateManager(ctx, config.Snapshots.Storages)
if err != nil {
return err
}

a.update(ctx, client, manager, config.Snapshots.StorageConfigDefaults)
a.update(ctx, client, storage.CreateManager(config.Snapshots.Storages), config.Snapshots.StorageConfigDefaults)
return nil
}

Expand Down
87 changes: 49 additions & 38 deletions internal/agent/snapshot-agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ func TestTakeSnapshotUploadsSnapshot(t *testing.T) {
Frequency: time.Millisecond,
}

controller := &storageControllerStub{
nextSnapshot: time.Now().Add(time.Millisecond * 250),
}
factory := &storageControllerFactoryStub{nextSnapshot: time.Now().Add(time.Millisecond * 250)}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -39,10 +37,10 @@ func TestTakeSnapshotUploadsSnapshot(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Equal(t, clientVaultAPI.snapshotData, controller.uploadData)
assert.Equal(t, defaults, controller.defaults)
assert.WithinRange(t, controller.snapshotTimestamp, start, start.Add(50*time.Millisecond))
assert.GreaterOrEqual(t, time.Now(), controller.nextSnapshot)
assert.Equal(t, clientVaultAPI.snapshotData, factory.uploadData)
assert.Equal(t, defaults, factory.defaults)
assert.WithinRange(t, factory.snapshotTimestamp, start, start.Add(50*time.Millisecond))
assert.GreaterOrEqual(t, time.Now(), factory.nextSnapshot)
}

func TestTakeSnapshotLocksTakeSnapshot(t *testing.T) {
Expand Down Expand Up @@ -119,12 +117,12 @@ func TestTakeSnapshotFailsWhenTempFileCannotBeCreated(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Now().Add(defaults.Frequency * 4),
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -135,7 +133,7 @@ func TestTakeSnapshotFailsWhenTempFileCannotBeCreated(t *testing.T) {
<-timer.C

assert.False(t, clientVaultAPI.tookSnapshot)
assert.Less(t, time.Now(), controller.nextSnapshot.Add(-defaults.Frequency))
assert.Less(t, time.Now(), factory.nextSnapshot.Add(-defaults.Frequency))
}

func TestTakeSnapshotFailsWhenSnapshottingFails(t *testing.T) {
Expand All @@ -148,12 +146,12 @@ func TestTakeSnapshotFailsWhenSnapshottingFails(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Now().Add(defaults.Frequency * 4),
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -164,7 +162,7 @@ func TestTakeSnapshotFailsWhenSnapshottingFails(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Less(t, time.Now(), controller.nextSnapshot.Add(-defaults.Frequency))
assert.Less(t, time.Now(), factory.nextSnapshot.Add(-defaults.Frequency))
}

func TestTakeSnapshotIgnoresEmptySnapshot(t *testing.T) {
Expand All @@ -176,12 +174,12 @@ func TestTakeSnapshotIgnoresEmptySnapshot(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Now().Add(defaults.Frequency * 4),
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -192,7 +190,7 @@ func TestTakeSnapshotIgnoresEmptySnapshot(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Less(t, time.Now(), controller.nextSnapshot.Add(-defaults.Frequency))
assert.Less(t, time.Now(), factory.nextSnapshot.Add(-defaults.Frequency))
}

func TestIgnoresZeroTimeForScheduling(t *testing.T) {
Expand All @@ -205,12 +203,12 @@ func TestIgnoresZeroTimeForScheduling(t *testing.T) {
Frequency: time.Millisecond * 150,
}

controller := &storageControllerStub{
factory := &storageControllerFactoryStub{
nextSnapshot: time.Time{},
}

manager := &storage.Manager{}
manager.AddStorage(controller)
manager.AddStorageFactory(factory)

ctx := context.Background()

Expand All @@ -222,7 +220,7 @@ func TestIgnoresZeroTimeForScheduling(t *testing.T) {
<-timer.C

assert.True(t, clientVaultAPI.tookSnapshot)
assert.Equal(t, clientVaultAPI.snapshotData, controller.uploadData)
assert.Equal(t, clientVaultAPI.snapshotData, factory.uploadData)
assert.GreaterOrEqual(t, time.Now(), start.Add(defaults.Frequency))
}

Expand All @@ -233,24 +231,29 @@ func TestUpdateReschedulesSnapshots(t *testing.T) {
}

manager := &storage.Manager{}
manager.AddStorage(&storageControllerStub{nextSnapshot: time.Now().Add(time.Millisecond * 250)})
factory := &storageControllerFactoryStub{nextSnapshot: time.Now().Add(time.Millisecond * 250)}
manager.AddStorageFactory(factory)

newController := &storageControllerStub{nextSnapshot: time.Now().Add(time.Millisecond * 500)}
newFactory := &storageControllerFactoryStub{nextSnapshot: time.Now().Add(time.Millisecond * 500)}
newManager := &storage.Manager{}
newManager.AddStorage(newController)
newManager.AddStorageFactory(newFactory)

ctx := context.Background()
agent := newSnapshotAgent(t.TempDir())
agent.update(ctx, newClient(clientVaultAPI), manager, storage.StorageConfigDefaults{})
client := newClient(clientVaultAPI)
agent.update(ctx, client, manager, storage.StorageConfigDefaults{})
timer := agent.TakeSnapshot(ctx)

updated := make(chan bool, 1)
go func() {
agent.update(ctx, newClient(clientVaultAPI), newManager, storage.StorageConfigDefaults{})
agent.update(ctx, client, newManager, storage.StorageConfigDefaults{})
updated <- true
}()

<-updated
<-timer.C

assert.GreaterOrEqual(t, time.Now(), newController.nextSnapshot)
assert.GreaterOrEqual(t, time.Now(), newFactory.nextSnapshot)
assert.Equal(t, newManager, agent.manager)
}

Expand Down Expand Up @@ -304,36 +307,44 @@ func (stub clientVaultAPIAuthStub) Login(_ context.Context, _ any) (time.Duratio
return 0, nil
}

type storageControllerStub struct {
type storageControllerFactoryStub struct {
defaults storage.StorageConfigDefaults
uploadData string
uploadFails bool
snapshotTimestamp time.Time
nextSnapshot time.Time
}

func (stub *storageControllerStub) Destination() string {
func (stub *storageControllerFactoryStub) Destination() string {
return ""
}

func (stub *storageControllerStub) ScheduleSnapshot(_ context.Context, _ time.Time, _ storage.StorageConfigDefaults) time.Time {
return stub.nextSnapshot
func (stub *storageControllerFactoryStub) CreateController(context.Context) (storage.StorageController, error) {
return storageControllerStub{stub}, nil
}

type storageControllerStub struct {
factory *storageControllerFactoryStub
}

func (stub storageControllerStub) ScheduleSnapshot(_ context.Context, _ time.Time, _ storage.StorageConfigDefaults) (time.Time, error) {
return stub.factory.nextSnapshot, nil
}

func (stub *storageControllerStub) DeleteObsoleteSnapshots(_ context.Context, _ storage.StorageConfigDefaults) (int, error) {
func (stub storageControllerStub) DeleteObsoleteSnapshots(_ context.Context, _ storage.StorageConfigDefaults) (int, error) {
return 0, nil
}

func (stub *storageControllerStub) UploadSnapshot(_ context.Context, snapshot io.Reader, timestamp time.Time, defaults storage.StorageConfigDefaults) (bool, time.Time, error) {
stub.snapshotTimestamp = timestamp
stub.defaults = defaults
if stub.uploadFails {
return false, stub.nextSnapshot, errors.New("upload failed")
func (stub storageControllerStub) UploadSnapshot(_ context.Context, snapshot io.Reader, timestamp time.Time, defaults storage.StorageConfigDefaults) (bool, time.Time, error) {
stub.factory.snapshotTimestamp = timestamp
stub.factory.defaults = defaults
if stub.factory.uploadFails {
return false, stub.factory.nextSnapshot, errors.New("upload failed")
}
data, err := io.ReadAll(snapshot)
if err != nil {
return false, time.Now(), err
}
stub.uploadData = string(data)
return true, stub.nextSnapshot, nil
stub.factory.uploadData = string(data)
return true, stub.factory.nextSnapshot, nil
}
64 changes: 34 additions & 30 deletions internal/agent/storage/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,46 +37,50 @@ type awsStorageImpl struct {
sse bool
}

func createAWSStorageController(ctx context.Context, config AWSStorageConfig) (*storageControllerImpl[types.Object], error) {
func (conf AWSStorageConfig) Destination() string {
return fmt.Sprintf("aws s3 bucket %s at %s", conf.Bucket, conf.Endpoint)
}

func (conf AWSStorageConfig) CreateController(ctx context.Context) (StorageController, error) {
keyPrefix := ""
if config.KeyPrefix != "" {
keyPrefix = fmt.Sprintf("%s/", config.KeyPrefix)
if conf.KeyPrefix != "" {
keyPrefix = fmt.Sprintf("%s/", conf.KeyPrefix)
}

client, err := createS3Client(ctx, config)
client, err := conf.createClient(ctx)
if err != nil {
return nil, nil
return nil, err
}

return newStorageController[types.Object](
config.storageConfig,
fmt.Sprintf("aws s3 bucket %s at %s", config.Bucket, config.Endpoint),
conf.storageConfig,
awsStorageImpl{
client: client,
keyPrefix: keyPrefix,
bucket: config.Bucket,
sse: config.UseServerSideEncryption,
bucket: conf.Bucket,
sse: conf.UseServerSideEncryption,
},
), nil

}

func createS3Client(ctx context.Context, config AWSStorageConfig) (*s3.Client, error) {
accessKeyId, err := config.AccessKeyId.Resolve(false)
func (conf AWSStorageConfig) createClient(ctx context.Context) (*s3.Client, error) {
accessKeyId, err := conf.AccessKeyId.Resolve(false)
if err != nil {
return nil, err
}

accessKey, err := config.AccessKey.Resolve(accessKeyId != "")
accessKey, err := conf.AccessKey.Resolve(accessKeyId != "")
if err != nil {
return nil, err
}

sessionToken, err := config.SessionToken.Resolve(false)
sessionToken, err := conf.SessionToken.Resolve(false)
if err != nil {
return nil, err
}

region, err := config.Region.Resolve(false)
region, err := conf.Region.Resolve(false)
if err != nil {
return nil, err
}
Expand All @@ -90,14 +94,14 @@ func createS3Client(ctx context.Context, config AWSStorageConfig) (*s3.Client, e
clientConfig.Credentials = credentials.NewStaticCredentialsProvider(accessKeyId, accessKey, sessionToken)
}

endpoint, err := config.Endpoint.Resolve(false)
endpoint, err := conf.Endpoint.Resolve(false)
if err != nil {
return nil, err
}

client := s3.NewFromConfig(clientConfig, func(o *s3.Options) {
o.UsePathStyle = config.ForcePathStyle
if config.Endpoint != "" {
o.UsePathStyle = conf.ForcePathStyle
if conf.Endpoint != "" {
o.BaseEndpoint = aws.String(endpoint)
}
})
Expand All @@ -107,18 +111,18 @@ func createS3Client(ctx context.Context, config AWSStorageConfig) (*s3.Client, e

// nolint:unused
// implements interface storage
func (u awsStorageImpl) UploadSnapshot(ctx context.Context, name string, data io.Reader) error {
func (s awsStorageImpl) uploadSnapshot(ctx context.Context, name string, data io.Reader) error {
input := &s3.PutObjectInput{
Bucket: &u.bucket,
Key: aws.String(u.keyPrefix + name),
Bucket: &s.bucket,
Key: aws.String(s.keyPrefix + name),
Body: data,
}

if u.sse {
if s.sse {
input.ServerSideEncryption = types.ServerSideEncryptionAes256
}

uploader := manager.NewUploader(u.client)
uploader := manager.NewUploader(s.client)
if _, err := uploader.Upload(ctx, input); err != nil {
return err
}
Expand All @@ -128,13 +132,13 @@ func (u awsStorageImpl) UploadSnapshot(ctx context.Context, name string, data io

// nolint:unused
// implements interface storage
func (u awsStorageImpl) DeleteSnapshot(ctx context.Context, snapshot types.Object) error {
func (s awsStorageImpl) deleteSnapshot(ctx context.Context, snapshot types.Object) error {
input := &s3.DeleteObjectInput{
Bucket: &u.bucket,
Bucket: &s.bucket,
Key: snapshot.Key,
}

if _, err := u.client.DeleteObject(ctx, input); err != nil {
if _, err := s.client.DeleteObject(ctx, input); err != nil {
return err
}

Expand All @@ -143,12 +147,12 @@ func (u awsStorageImpl) DeleteSnapshot(ctx context.Context, snapshot types.Objec

// nolint:unused
// implements interface storage
func (u awsStorageImpl) ListSnapshots(ctx context.Context, prefix string, ext string) ([]types.Object, error) {
func (s awsStorageImpl) listSnapshots(ctx context.Context, prefix string, ext string) ([]types.Object, error) {
var result []types.Object

existingSnapshotList, err := u.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: &u.bucket,
Prefix: aws.String(u.keyPrefix),
existingSnapshotList, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: &s.bucket,
Prefix: aws.String(s.keyPrefix),
})

if err != nil {
Expand All @@ -166,6 +170,6 @@ func (u awsStorageImpl) ListSnapshots(ctx context.Context, prefix string, ext st

// nolint:unused
// implements interface storage
func (u awsStorageImpl) GetLastModifiedTime(snapshot types.Object) time.Time {
func (s awsStorageImpl) getLastModifiedTime(snapshot types.Object) time.Time {
return *snapshot.LastModified
}
Loading
Loading