diff --git a/internal/raft/cluster.go b/internal/raft/cluster.go index d82716e8fe..abcea04df7 100644 --- a/internal/raft/cluster.go +++ b/internal/raft/cluster.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/jpillora/backoff" @@ -24,7 +25,7 @@ type RaftConfig struct { ShardReadyTimeout time.Duration `help:"Timeout for shard to be ready" default:"5s"` // Raft configuration RTT time.Duration `help:"Estimated average round trip time between nodes" default:"200ms"` - ElectionTimeoutRTT uint64 `help:"Election timeout RTT as a multiple of RTT" default:"10"` + ElectionRTT uint64 `help:"Election RTT as a multiple of RTT" default:"10"` HeartbeatRTT uint64 `help:"Heartbeat RTT as a multiple of RTT" default:"1"` SnapshotEntries uint64 `help:"Snapshot entries" default:"10"` CompactionOverhead uint64 `help:"Compaction overhead" default:"100"` @@ -91,10 +92,16 @@ type ShardHandle[E Event, Q any, R any] struct { shardID uint64 cluster *Cluster session *client.Session + + mu sync.Mutex } // Propose an event to the shard. func (s *ShardHandle[E, Q, R]) Propose(ctx context.Context, msg E) error { + // client session is not thread safe, so we need to lock + s.mu.Lock() + defer s.mu.Unlock() + s.verifyReady() msgBytes, err := msg.MarshalBinary() @@ -102,13 +109,22 @@ func (s *ShardHandle[E, Q, R]) Propose(ctx context.Context, msg E) error { return fmt.Errorf("failed to marshal event: %w", err) } if s.session == nil { - // use a no-op session for now. This means that a retry on timeout could result into duplicate events. - s.session = s.cluster.nh.GetNoOPSession(s.shardID) + if err := s.cluster.withRetry(ctx, s.shardID, s.cluster.config.ReplicaID, func(ctx context.Context) error { + s.session, err = s.cluster.nh.SyncGetSession(ctx, s.shardID) + return err //nolint:wrapcheck + }); err != nil { + return fmt.Errorf("failed to get session: %w", err) + } } if err := s.cluster.withRetry(ctx, s.shardID, s.cluster.config.ReplicaID, func(ctx context.Context) error { + s.session.PrepareForPropose() _, err := s.cluster.nh.SyncPropose(ctx, s.session, msgBytes) - return err //nolint:wrapcheck + if err != nil { + return err //nolint:wrapcheck + } + s.session.ProposalCompleted() + return nil }); err != nil { return fmt.Errorf("failed to propose event: %w", err) } @@ -177,7 +193,7 @@ func (c *Cluster) start(ctx context.Context, join bool) error { ReplicaID: c.config.ReplicaID, ShardID: shardID, CheckQuorum: true, - ElectionRTT: c.config.ElectionTimeoutRTT, + ElectionRTT: c.config.ElectionRTT, HeartbeatRTT: c.config.HeartbeatRTT, SnapshotEntries: c.config.SnapshotEntries, CompactionOverhead: c.config.CompactionOverhead, @@ -208,21 +224,16 @@ func (c *Cluster) start(ctx context.Context, join bool) error { } // Stop the node host and all shards. -func (c *Cluster) Stop(ctx context.Context) error { - if c.nh == nil { - return nil - } - - for shardID := range c.shards { - if err := c.removeShardMember(ctx, shardID, c.config.ReplicaID); err != nil { - return fmt.Errorf("failed to remove shard (%d) member: %w", shardID, err) +// After this call, all the shard handlers created with this cluster are invalid. +func (c *Cluster) Stop(ctx context.Context) { + if c.nh != nil { + for shardID := range c.shards { + c.removeShardMember(ctx, shardID, c.config.ReplicaID) } + c.nh.Close() + c.nh = nil + c.shards = nil } - - c.nh.Close() - c.nh = nil - - return nil } // AddMember to the cluster. This needs to be called on an existing running cluster member, @@ -241,16 +252,16 @@ func (c *Cluster) AddMember(ctx context.Context, shardID uint64, replicaID uint6 // removeShardMember from the given shard. This removes the given member from the membership group // and blocks until the change has been committed -func (c *Cluster) removeShardMember(ctx context.Context, shardID uint64, replicaID uint64) error { +func (c *Cluster) removeShardMember(ctx context.Context, shardID uint64, replicaID uint64) { logger := log.FromContext(ctx).Scope("raft") logger.Infof("removing replica %d from shard %d", shardID, replicaID) if err := c.withRetry(ctx, shardID, replicaID, func(ctx context.Context) error { return c.nh.SyncRequestDeleteReplica(ctx, shardID, replicaID, 0) }); err != nil { - return fmt.Errorf("failed to remove member: %w", err) + // This can happen if the cluster is shutting down and no longer has quorum. + logger.Warnf("removing replica %d from shard %d failed: %s", replicaID, shardID, err) } - return nil } // withTimeout runs an async dragonboat call and blocks until it succeeds or the context is cancelled. @@ -268,7 +279,7 @@ func (c *Cluster) withRetry(ctx context.Context, shardID, replicaID uint64, f fu // Timeout for the proposal to reach the leader and reach a quorum. // If the leader is not available, the proposal will time out, in which case // we retry the operation. - timeout := time.Duration(c.config.ElectionTimeoutRTT) * c.config.RTT + timeout := time.Duration(c.config.ElectionRTT) * c.config.RTT ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() diff --git a/internal/raft/cluster_test.go b/internal/raft/cluster_test.go index 5545325f75..d95c4b1b71 100644 --- a/internal/raft/cluster_test.go +++ b/internal/raft/cluster_test.go @@ -45,7 +45,7 @@ func (s *IntStateMachine) Close() error { return nil } func TestCluster(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(20*time.Second)) - defer cancel() + t.Cleanup(cancel) members, err := local.FreeTCPAddresses(2) assert.NoError(t, err) @@ -64,8 +64,10 @@ func TestCluster(t *testing.T) { wg.Go(func() error { return cluster1.Start(wctx) }) wg.Go(func() error { return cluster2.Start(wctx) }) assert.NoError(t, wg.Wait()) - defer cluster1.Stop(ctx) //nolint:errcheck - defer cluster2.Stop(ctx) //nolint:errcheck + t.Cleanup(func() { + cluster1.Stop(ctx) + cluster2.Stop(ctx) + }) assert.NoError(t, shard1_1.Propose(ctx, IntEvent(1))) assert.NoError(t, shard2_1.Propose(ctx, IntEvent(2))) @@ -80,7 +82,7 @@ func TestCluster(t *testing.T) { func TestJoiningExistingCluster(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(20*time.Second)) - defer cancel() + t.Cleanup(cancel) members, err := local.FreeTCPAddresses(4) assert.NoError(t, err) @@ -97,8 +99,10 @@ func TestJoiningExistingCluster(t *testing.T) { wg.Go(func() error { return cluster1.Start(wctx) }) wg.Go(func() error { return cluster2.Start(wctx) }) assert.NoError(t, wg.Wait()) - defer cluster1.Stop(ctx) //nolint:errcheck - defer cluster2.Stop(ctx) //nolint:errcheck + t.Cleanup(func() { + cluster1.Stop(ctx) + cluster2.Stop(ctx) + }) t.Log("join to the existing cluster as a new member") builder3 := testBuilder(t, nil, 3, members[2].String()) @@ -131,7 +135,7 @@ func TestJoiningExistingCluster(t *testing.T) { func TestLeavingCluster(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(20*time.Second)) - defer cancel() + t.Cleanup(cancel) members, err := local.FreeTCPAddresses(3) assert.NoError(t, err) @@ -151,16 +155,18 @@ func TestLeavingCluster(t *testing.T) { wg.Go(func() error { return cluster2.Start(wctx) }) wg.Go(func() error { return cluster3.Start(wctx) }) assert.NoError(t, wg.Wait()) - defer cluster1.Stop(ctx) //nolint:errcheck - defer cluster2.Stop(ctx) //nolint:errcheck - defer cluster3.Stop(ctx) //nolint:errcheck + t.Cleanup(func() { + cluster1.Stop(ctx) + cluster2.Stop(ctx) + cluster3.Stop(ctx) + }) t.Log("proposing event") assert.NoError(t, shard1.Propose(ctx, IntEvent(1))) assertShardValue(ctx, t, 1, shard1, shard2, shard3) t.Log("removing member") - assert.NoError(t, cluster1.Stop(ctx)) + cluster1.Stop(ctx) t.Log("proposing event after removal") assert.NoError(t, shard2.Propose(ctx, IntEvent(1))) @@ -179,7 +185,7 @@ func testBuilder(t *testing.T, addresses []*net.TCPAddr, id uint64, address stri DataDir: t.TempDir(), InitialMembers: members, HeartbeatRTT: 1, - ElectionTimeoutRTT: 10, + ElectionRTT: 5, SnapshotEntries: 10, CompactionOverhead: 10, RTT: 10 * time.Millisecond, diff --git a/internal/raft/eventview_test.go b/internal/raft/eventview_test.go index 6055785b62..d9bfd3abda 100644 --- a/internal/raft/eventview_test.go +++ b/internal/raft/eventview_test.go @@ -46,7 +46,7 @@ func (v *IntSumView) UnmarshalBinary(data []byte) error { func TestEventView(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(60*time.Second)) - defer cancel() + t.Cleanup(cancel) members, err := local.FreeTCPAddresses(2) assert.NoError(t, err) @@ -63,8 +63,10 @@ func TestEventView(t *testing.T) { eg.Go(func() error { return cluster1.Start(wctx) }) eg.Go(func() error { return cluster2.Start(wctx) }) assert.NoError(t, eg.Wait()) - defer cluster1.Stop(ctx) //nolint:errcheck - defer cluster2.Stop(ctx) //nolint:errcheck + t.Cleanup(func() { + cluster1.Stop(ctx) + cluster2.Stop(ctx) + }) assert.NoError(t, view1.Publish(ctx, IntStreamEvent{Value: 1}))