diff --git a/cmd/raft-tester/main.go b/cmd/raft-tester/main.go index c21e3d758..be2581b0d 100644 --- a/cmd/raft-tester/main.go +++ b/cmd/raft-tester/main.go @@ -94,7 +94,7 @@ func main() { } }) wg.Go(func() error { - return cluster.Start(ctx, nil) + return cluster.Start(ctx) }) wg.Go(func() error { ticker := time.NewTicker(10 * time.Second) diff --git a/internal/raft/cluster.go b/internal/raft/cluster.go index a9d28781c..ee85ccae5 100644 --- a/internal/raft/cluster.go +++ b/internal/raft/cluster.go @@ -11,7 +11,6 @@ import ( "github.com/lni/dragonboat/v4/client" "github.com/lni/dragonboat/v4/config" "github.com/lni/dragonboat/v4/statemachine" - "github.com/lni/goutils/syncutil" ) type RaftConfig struct { @@ -110,8 +109,17 @@ func AddShard[Q any, R any, E Event, EPtr Unmasrshallable[E]]( } } -// Start the cluster. -func (c *Cluster) Start(ctx context.Context, ready chan struct{}) error { +// Start starts the cluster. It blocks until the cluster instance is ready. +func (c *Cluster) Start(ctx context.Context) error { + return c.start(ctx, false) +} + +// Join joins the cluster as a new member. It blocks until the cluster instance is ready.
+func (c *Cluster) Join(ctx context.Context) error { + return c.start(ctx, true) +} + +func (c *Cluster) start(ctx context.Context, join bool) error { // Create node host config nhc := config.NodeHostConfig{ WALDir: c.config.DataDir, @@ -142,22 +150,18 @@ func (c *Cluster) Start(ctx context.Context, ready chan struct{}) error { } peers := make(map[uint64]string) - for idx, peer := range c.config.InitialMembers { - peers[uint64(idx+1)] = peer + if !join { + for idx, peer := range c.config.InitialMembers { + peers[uint64(idx+1)] = peer + } } // Start the raft node for this shard - if err := nh.StartReplica(peers, false, sm, cfg); err != nil { + if err := nh.StartReplica(peers, join, sm, cfg); err != nil { return fmt.Errorf("failed to start replica for shard %d: %w", shardID, err) } } - raftstopper := syncutil.NewStopper() - raftstopper.RunWorker(func() { - <-ctx.Done() - c.nh.Close() - }) - // Wait for all shards to be ready // TODO: WaitReady in the config should do this, but for some reason it doesn't work. for shardID := range c.shards { @@ -166,17 +170,20 @@ func (c *Cluster) Start(ctx context.Context, ready chan struct{}) error { } } - if ready != nil { - ready <- struct{}{} - } - raftstopper.Wait() + return nil +} - for shardID := range c.shards { - if err := c.nh.StopReplica(shardID, c.config.ReplicaID); err != nil { - return fmt.Errorf("failed to stop replica for shard %d: %w", shardID, err) - } - } +// Stop shuts down the cluster instance by closing the underlying node host. +func (c *Cluster) Stop() { + c.nh.Close() +} +// AddMember adds a new member to the cluster. It must be called on an existing running cluster member, +// before the new member is started.
+func (c *Cluster) AddMember(ctx context.Context, shardID uint64, replicaID uint64, address string) error { + if err := c.nh.SyncRequestAddReplica(ctx, shardID, replicaID, address, 0); err != nil { + return fmt.Errorf("failed to add member: %w", err) + } return nil } diff --git a/internal/raft/cluster_test.go b/internal/raft/cluster_test.go index 652836ebd..27a431e18 100644 --- a/internal/raft/cluster_test.go +++ b/internal/raft/cluster_test.go @@ -9,6 +9,7 @@ import ( "github.com/alecthomas/assert/v2" "github.com/block/ftl/internal/raft" + "golang.org/x/sync/errgroup" ) type IntEvent int64 @@ -44,46 +45,80 @@ func TestCluster(t *testing.T) { members := []string{"localhost:51001", "localhost:51002"} - cluster1 := testCluster(t, members, 1) + cluster1 := testCluster(t, members, 1, members[0]) shard1_1 := raft.AddShard(ctx, cluster1, 1, &IntStateMachine{}) shard1_2 := raft.AddShard(ctx, cluster1, 2, &IntStateMachine{}) - cluster2 := testCluster(t, members, 2) + cluster2 := testCluster(t, members, 2, members[1]) shard2_1 := raft.AddShard(ctx, cluster2, 1, &IntStateMachine{}) shard2_2 := raft.AddShard(ctx, cluster2, 2, &IntStateMachine{}) - ready := make(chan struct{}) - go cluster1.Start(ctx, ready) //nolint:errcheck - go cluster2.Start(ctx, ready) //nolint:errcheck - <-ready - <-ready + wg, wctx := errgroup.WithContext(ctx) + wg.Go(func() error { return cluster1.Start(wctx) }) + wg.Go(func() error { return cluster2.Start(wctx) }) + assert.NoError(t, wg.Wait()) + defer cluster1.Stop() + defer cluster2.Stop() assert.NoError(t, shard1_1.Propose(ctx, IntEvent(1))) assert.NoError(t, shard2_1.Propose(ctx, IntEvent(2))) + assert.NoError(t, shard1_2.Propose(ctx, IntEvent(1))) assert.NoError(t, shard2_2.Propose(ctx, IntEvent(1))) - res, err := shard1_1.Query(ctx, 0) - assert.NoError(t, err) - assert.Equal(t, res, int64(3)) + assertShardValue(ctx, t, 3, shard1_1, shard2_1) + assertShardValue(ctx, t, 2, shard1_2, shard2_2) +} + +func TestJoiningExistingCluster(t *testing.T) { +
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) + defer cancel() + + members := []string{"localhost:51001", "localhost:51002"} + + cluster1 := testCluster(t, members, 1, members[0]) + shard1 := raft.AddShard(ctx, cluster1, 1, &IntStateMachine{}) + + cluster2 := testCluster(t, members, 2, members[1]) + shard2 := raft.AddShard(ctx, cluster2, 1, &IntStateMachine{}) + + wg, wctx := errgroup.WithContext(ctx) + wg.Go(func() error { return cluster1.Start(wctx) }) + wg.Go(func() error { return cluster2.Start(wctx) }) + assert.NoError(t, wg.Wait()) + defer cluster1.Stop() + defer cluster2.Stop() - res, err = shard2_1.Query(ctx, 0) - assert.NoError(t, err) - assert.Equal(t, res, int64(3)) + // join the existing cluster as a new member + cluster3 := testCluster(t, nil, 3, "localhost:51003") + shard3 := raft.AddShard(ctx, cluster3, 1, &IntStateMachine{}) - res, err = shard1_2.Query(ctx, 0) - assert.NoError(t, err) - assert.Equal(t, res, int64(2)) + assert.NoError(t, cluster1.AddMember(ctx, 1, 3, "localhost:51003")) - res, err = shard2_2.Query(ctx, 0) - assert.NoError(t, err) - assert.Equal(t, res, int64(2)) + assert.NoError(t, cluster3.Join(ctx)) + defer cluster3.Stop() + + assert.NoError(t, shard3.Propose(ctx, IntEvent(1))) + + assertShardValue(ctx, t, 1, shard1, shard2, shard3) + + // join through the new member + cluster4 := testCluster(t, nil, 4, "localhost:51004") + shard4 := raft.AddShard(ctx, cluster4, 1, &IntStateMachine{}) + + assert.NoError(t, cluster3.AddMember(ctx, 1, 4, "localhost:51004")) + assert.NoError(t, cluster4.Join(ctx)) + defer cluster4.Stop() + + assert.NoError(t, shard4.Propose(ctx, IntEvent(1))) + + assertShardValue(ctx, t, 2, shard1, shard2, shard3, shard4) } -func testCluster(t *testing.T, members []string, id uint64) *raft.Cluster { +func testCluster(t *testing.T, members []string, id uint64, address string) *raft.Cluster { return raft.New(&raft.RaftConfig{ ReplicaID: id, - RaftAddress: members[id-1], +
RaftAddress: address, DataDir: t.TempDir(), InitialMembers: members, HeartbeatRTT: 1, @@ -92,3 +127,13 @@ func testCluster(t *testing.T, members []string, id uint64) *raft.Cluster { CompactionOverhead: 10, }) } + +func assertShardValue(ctx context.Context, t *testing.T, expected int64, shards ...*raft.ShardHandle[IntEvent, int64, int64]) { + t.Helper() + + for _, shard := range shards { + res, err := shard.Query(ctx, 0) + assert.NoError(t, err) + assert.Equal(t, expected, res) + } +}