diff --git a/lib/backend/clone/clone.go b/lib/backend/clone/clone.go index 0043735a45cb6..6b2a793ff4565 100644 --- a/lib/backend/clone/clone.go +++ b/lib/backend/clone/clone.go @@ -2,6 +2,7 @@ package clone import ( "context" + "fmt" "log/slog" "sync/atomic" "time" @@ -25,6 +26,7 @@ type Cloner struct { src backend.Backend dst backend.Backend parallel int + force bool migrated atomic.Int64 log *slog.Logger } @@ -40,6 +42,9 @@ type Config struct { Destination backend.Config `yaml:"dst"` // Parallel is the number of items that will be cloned in parallel. Parallel int `yaml:"parallel"` + // Force indicates whether to clone data regardless of whether data already + // exists in the destination [backend.Backend]. + Force bool `yaml:"force"` // Log logs the progress of cloning. Log *slog.Logger } @@ -58,6 +63,7 @@ func New(ctx context.Context, config Config) (*Cloner, error) { src: src, dst: dst, parallel: config.Parallel, + force: config.Force, log: config.Log, } if cloner.parallel <= 0 { @@ -95,6 +101,18 @@ func (c *Cloner) Clone(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + if !c.force { + result, err := c.dst.GetRange(ctx, start, backend.RangeEnd(start), 1) + if err != nil { + return trace.Wrap(err, "failed to check destination for existing data") + } + if len(result.Items) > 0 { + return trace.Errorf("unable to clone data to destination with existing data; this may be overriden by configuring 'force: true'") + } + } else { + c.log.Warn("Skipping check for existing data in destination.") + } + group, ctx := errgroup.WithContext(ctx) // Add 1 to ensure a goroutine exists for getting items. group.SetLimit(c.parallel + 1) @@ -130,7 +148,7 @@ func (c *Cloner) Clone(ctx context.Context) error { }) logProgress := func() { - c.log.Info("Migrated %d", c.migrated.Load()) + c.log.Info(fmt.Sprintf("Migrated %d", c.migrated.Load())) } defer logProgress() go func() { diff --git a/lib/backend/clone/clone_test.go b/lib/backend/clone/clone_test.go index 8cf96c463fa60..574fe4636f14b 100644 --- a/lib/backend/clone/clone_test.go +++ b/lib/backend/clone/clone_test.go @@ -14,7 +14,7 @@ import ( logutils "github.com/gravitational/teleport/lib/utils/log" ) -func TestMigration(t *testing.T) { +func TestClone(t *testing.T) { ctx := context.Background() src, err := memory.New(memory.Config{}) require.NoError(t, err) @@ -54,3 +54,51 @@ func TestMigration(t *testing.T) { require.Equal(t, itemCount, int(cloner.migrated.Load())) require.NoError(t, cloner.Close()) } + +func TestCloneForce(t *testing.T) { + ctx := context.Background() + src, err := memory.New(memory.Config{}) + require.NoError(t, err) + + dst, err := memory.New(memory.Config{}) + require.NoError(t, err) + + itemCount := 100 + items := make([]backend.Item, itemCount) + + for i := 0; i < itemCount; i++ { + item := backend.Item{ + Key: backend.Key(fmt.Sprintf("key-%05d", i)), + Value: []byte(fmt.Sprintf("value-%d", i)), + } + _, err := src.Put(ctx, item) + require.NoError(t, err) + items[i] = item + } + + _, err = dst.Put(ctx, items[0]) + require.NoError(t, err) + + cloner := Cloner{ + src: src, + dst: dst, + parallel: 10, + log: logutils.NewPackageLogger(), + } + + err = cloner.Clone(ctx) + require.Error(t, err) + + cloner.force = true + err = cloner.Clone(ctx) + require.NoError(t, err) + + start := backend.Key("") + result, err := dst.GetRange(ctx, start, backend.RangeEnd(start), 0) + require.NoError(t, err) + + diff := cmp.Diff(items, result.Items, cmpopts.IgnoreFields(backend.Item{}, "Revision", "ID")) + require.Empty(t, diff) + require.Equal(t, itemCount, int(cloner.migrated.Load())) + require.NoError(t, cloner.Close()) +}