From 8bcb8506a9e5a0da47ad4d0fd6f64331e87a1a8e Mon Sep 17 00:00:00 2001
From: FZambia
Date: Tue, 3 Dec 2024 19:49:10 +0200
Subject: [PATCH] kafka: fix possibility to lose records under load

---
 internal/consuming/kafka.go      |  78 ++++++-----
 internal/consuming/kafka_test.go | 187 ++++++++++++++++++++++++++++++-
 2 files changed, 234 insertions(+), 31 deletions(-)

diff --git a/internal/consuming/kafka.go b/internal/consuming/kafka.go
index 44cf1cbbd..7720e43e3 100644
--- a/internal/consuming/kafka.go
+++ b/internal/consuming/kafka.go
@@ -42,6 +42,10 @@ type KafkaConfig struct {
 	// will pause fetching records from Kafka. By default, this is 16.
 	// Set to -1 to use non-buffered channel.
 	PartitionBufferSize int `mapstructure:"partition_buffer_size" json:"partition_buffer_size"`
+
+	// FetchMaxBytes is the maximum number of bytes to fetch from Kafka in a single request.
+	// If not set, the default of 50MB is used.
+	FetchMaxBytes int32 `mapstructure:"fetch_max_bytes" json:"fetch_max_bytes"`
 }
 
 type topicPartition struct {
@@ -142,6 +146,9 @@ func (c *KafkaConsumer) initClient() (*kgo.Client, error) {
 		kgo.ClientID(kafkaClientID),
 		kgo.InstanceID(c.getInstanceID()),
 	}
+	if c.config.FetchMaxBytes > 0 {
+		opts = append(opts, kgo.FetchMaxBytes(c.config.FetchMaxBytes))
+	}
 	if c.config.TLS {
 		tlsOptionsMap, err := c.config.TLSOptions.ToMap()
 		if err != nil {
@@ -284,12 +291,19 @@ func (c *KafkaConsumer) pollUntilFatal(ctx context.Context) error {
 			return fmt.Errorf("poll error: %w", errors.Join(errs...))
 		}
 
+		pausedTopicPartitions := map[topicPartition]struct{}{}
 		fetches.EachPartition(func(p kgo.FetchTopicPartition) {
 			if len(p.Records) == 0 {
 				return
 			}
 
 			tp := topicPartition{p.Topic, p.Partition}
+			if _, paused := pausedTopicPartitions[tp]; paused {
+				// We have already paused this partition during this poll, so we should not
+				// process any more records from it. We will resume processing this partition
+				// from the correct offset soon, after we have space in the recs buffer.
+				return
+			}
 
 			// Since we are using BlockRebalanceOnPoll, we can be
 			// sure this partition consumer exists:
@@ -310,6 +324,12 @@ func (c *KafkaConsumer) pollUntilFatal(ctx context.Context) error {
 				// keeping records in memory and blocking rebalance. Resume will be called after
 				// records are processed by c.consumers[tp].
 				c.client.PauseFetchPartitions(partitionsToPause)
+				pausedTopicPartitions[tp] = struct{}{}
+				// To make the next poll start from the correct offset, we need to set it manually
+				// to the offset of the first record in the batch. Otherwise, the next poll will
+				// return the next record batch, and we will lose the current one.
+				epochOffset := kgo.EpochOffset{Epoch: -1, Offset: p.Records[0].Offset}
+				c.client.SetOffsets(map[string]map[int32]kgo.EpochOffset{p.Topic: {p.Partition: epochOffset}})
 			}
 		})
 		c.client.AllowRebalance()
@@ -355,15 +375,25 @@ func (c *KafkaConsumer) assigned(ctx context.Context, cl *kgo.Client, assigned m
 	}
 	for topic, partitions := range assigned {
 		for _, partition := range partitions {
+			quitCh := make(chan struct{})
+			partitionCtx, cancel := context.WithCancel(ctx)
+			go func() {
+				select {
+				case <-ctx.Done():
+					cancel()
+				case <-quitCh:
+					cancel()
+				}
+			}()
 			pc := &partitionConsumer{
-				clientCtx:  ctx,
-				dispatcher: c.dispatcher,
-				logger:     c.logger,
-				cl:         cl,
-				topic:      topic,
-				partition:  partition,
-
-				quit: make(chan struct{}),
+				partitionCtx: partitionCtx,
+				dispatcher:   c.dispatcher,
+				logger:       c.logger,
+				cl:           cl,
+				topic:        topic,
+				partition:    partition,
+
+				quit: quitCh,
 				done: make(chan struct{}),
 				recs: make(chan kgo.FetchTopicPartition, bufferSize),
 			}
@@ -407,12 +437,12 @@ func (c *KafkaConsumer) killConsumers(lost map[string][]int32) {
 }
 
 type partitionConsumer struct {
-	clientCtx  context.Context
-	dispatcher Dispatcher
-	logger     Logger
-	cl         *kgo.Client
-	topic      string
-	partition  int32
+	partitionCtx context.Context
+	dispatcher   Dispatcher
+	logger       Logger
+	cl           *kgo.Client
+	topic        string
+	partition    int32
 
 	quit chan struct{}
 	done chan struct{}
@@ -422,9 +452,7 @@ type partitionConsumer struct {
 func (pc *partitionConsumer) processRecords(records []*kgo.Record) {
 	for _, record := range records {
 		select {
-		case <-pc.clientCtx.Done():
-			return
-		case <-pc.quit:
+		case <-pc.partitionCtx.Done():
 			return
 		default:
 		}
@@ -432,7 +460,7 @@
 		var e KafkaJSONEvent
 		err := json.Unmarshal(record.Value, &e)
 		if err != nil {
-			pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error unmarshalling event from Kafka", map[string]any{"error": err.Error(), "topic": record.Topic, "partition": record.Partition}))
+			pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error unmarshalling record value from Kafka", map[string]any{"error": err.Error(), "topic": record.Topic, "partition": record.Partition}))
 			pc.cl.MarkCommitRecords(record)
 			continue
 		}
@@ -440,22 +468,20 @@
 		var backoffDuration time.Duration = 0
 		retries := 0
 		for {
-			err := pc.dispatcher.Dispatch(pc.clientCtx, e.Method, e.Payload)
+			err := pc.dispatcher.Dispatch(pc.partitionCtx, e.Method, e.Payload)
 			if err == nil {
 				if retries > 0 {
-					pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelInfo, "OK processing events after errors", map[string]any{}))
+					pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelInfo, "OK processing record after errors", map[string]any{}))
 				}
 				pc.cl.MarkCommitRecords(record)
 				break
 			}
 			retries++
 			backoffDuration = getNextBackoffDuration(backoffDuration, retries)
-			pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error processing consumed event", map[string]any{"error": err.Error(), "method": e.Method, "nextAttemptIn": backoffDuration.String()}))
+			pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error processing consumed record", map[string]any{"error": err.Error(), "method": e.Method, "next_attempt_in": backoffDuration.String()}))
 			select {
 			case <-time.After(backoffDuration):
-			case <-pc.quit:
-				return
-			case <-pc.clientCtx.Done():
+			case <-pc.partitionCtx.Done():
 				return
 			}
 		}
@@ -471,9 +497,7 @@ func (pc *partitionConsumer) consume() {
 	defer resumeConsuming()
 	for {
 		select {
-		case <-pc.clientCtx.Done():
-			return
-		case <-pc.quit:
+		case <-pc.partitionCtx.Done():
 			return
 		case p := <-pc.recs:
 			pc.processRecords(p.Records)
diff --git a/internal/consuming/kafka_test.go b/internal/consuming/kafka_test.go
index 60e782aff..cbb61dad6 100644
--- a/internal/consuming/kafka_test.go
+++ b/internal/consuming/kafka_test.go
@@ -7,7 +7,9 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"strconv"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -44,15 +46,26 @@ func (m *MockLogger) Log(_ centrifuge.LogEntry) {
 	// Implement mock logic, e.g., storing log entries for assertions
 }
 
+func produceManyRecords(records ...*kgo.Record) error {
+	client, err := kgo.NewClient(kgo.SeedBrokers(testKafkaBrokerURL))
+	if err != nil {
+		return fmt.Errorf("failed to create Kafka client: %w", err)
+	}
+	defer client.Close()
+	err = client.ProduceSync(context.Background(), records...).FirstErr()
+	if err != nil {
+		return fmt.Errorf("failed to produce message: %w", err)
+	}
+	return nil
+}
+
 func produceTestMessage(topic string, message []byte) error {
-	// Create a new client
 	client, err := kgo.NewClient(kgo.SeedBrokers(testKafkaBrokerURL))
 	if err != nil {
 		return fmt.Errorf("failed to create Kafka client: %w", err)
 	}
 	defer client.Close()
 
-	// Produce a message
 	err = client.ProduceSync(context.Background(), &kgo.Record{Topic: topic, Partition: 0, Value: message}).FirstErr()
 	if err != nil {
 		return fmt.Errorf("failed to produce message: %w", err)
@@ -61,7 +74,6 @@ func produceTestMessage(topic string, message []byte) error {
 }
 
 func produceTestMessageToPartition(topic string, message []byte, partition int32) error {
-	// Create a new client.
 	client, err := kgo.NewClient(
 		kgo.SeedBrokers(testKafkaBrokerURL),
 		kgo.RecordPartitioner(kgo.ManualPartitioner()),
@@ -71,7 +83,6 @@
 	}
 	defer client.Close()
 
-	// Produce a message until we hit desired partition.
 	res := client.ProduceSync(context.Background(), &kgo.Record{
 		Topic: topic, Partition: partition, Value: message})
 	if res.FirstErr() != nil {
@@ -427,3 +438,171 @@ func TestKafkaConsumer_BlockedPartitionDoesNotBlockAnotherPartition(t *testing.T
 		})
 	}
 }
+
+func TestKafkaConsumer_PausePartitions(t *testing.T) {
+	t.Parallel()
+	testKafkaTopic := "consumer_test_" + uuid.New().String()
+	testPayload1 := []byte(`{"key":"value1"}`)
+	testPayload2 := []byte(`{"key":"value2"}`)
+	testPayload3 := []byte(`{"key":"value3"}`)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	err := createTestTopic(ctx, testKafkaTopic, 1, 1)
+	require.NoError(t, err)
+
+	event1Received := make(chan struct{})
+	event2Received := make(chan struct{})
+	event3Received := make(chan struct{})
+	consumerClosed := make(chan struct{})
+	doneCh := make(chan struct{})
+
+	config := KafkaConfig{
+		Brokers:             []string{testKafkaBrokerURL},
+		Topics:              []string{testKafkaTopic},
+		ConsumerGroup:       uuid.New().String(),
+		PartitionBufferSize: -1,
+	}
+
+	numCalls := 0
+
+	mockDispatcher := &MockDispatcher{
+		onDispatch: func(ctx context.Context, method string, data []byte) error {
+			numCalls++
+			if numCalls == 1 {
+				close(event1Received)
+				time.Sleep(5 * time.Second)
+				return nil
+			} else if numCalls == 2 {
+				close(event2Received)
+				return nil
+			}
+			close(event3Received)
+			return nil
+		},
+	}
+	consumer, err := NewKafkaConsumer("test", uuid.NewString(), &MockLogger{}, mockDispatcher, config)
+	require.NoError(t, err)
+
+	go func() {
+		err = produceTestMessage(testKafkaTopic, testPayload1)
+		require.NoError(t, err)
+		<-event1Received
+		// At this point message 1 is being processed and the next produced message will
+		// cause a partition pause.
+		err = produceTestMessage(testKafkaTopic, testPayload2)
+		require.NoError(t, err)
+		<-event2Received
+		err = produceTestMessage(testKafkaTopic, testPayload3)
+		require.NoError(t, err)
+	}()
+
+	go func() {
+		err := consumer.Run(ctx)
+		require.ErrorIs(t, err, context.Canceled)
+		close(consumerClosed)
+	}()
+
+	waitCh(t, event1Received, 30*time.Second, "timeout waiting for event 1")
+	waitCh(t, event2Received, 30*time.Second, "timeout waiting for event 2")
+	waitCh(t, event3Received, 30*time.Second, "timeout waiting for event 3")
+	cancel()
+	waitCh(t, consumerClosed, 30*time.Second, "timeout waiting for consumer closed")
+	close(doneCh)
+}
+
+func TestKafkaConsumer_WorksCorrectlyInLoadedTopic(t *testing.T) {
+	t.Skip()
+	t.Parallel()
+
+	testCases := []struct {
+		numPartitions   int32
+		numMessages     int
+		partitionBuffer int
+	}{
+		//{numPartitions: 1, numMessages: 1000, partitionBuffer: -1},
+		//{numPartitions: 1, numMessages: 1000, partitionBuffer: 1},
+		//{numPartitions: 10, numMessages: 10000, partitionBuffer: -1},
+		{numPartitions: 10, numMessages: 10000, partitionBuffer: 1},
+	}
+
+	for _, tc := range testCases {
+		name := fmt.Sprintf("partitions=%d,messages=%d,buffer=%d", tc.numPartitions, tc.numMessages, tc.partitionBuffer)
+		t.Run(name, func(t *testing.T) {
+			testKafkaTopic := "consumer_test_" + uuid.New().String()
+
+			ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
+			defer cancel()
+
+			err := createTestTopic(ctx, testKafkaTopic, tc.numPartitions, 1)
+			require.NoError(t, err)
+
+			consumerClosed := make(chan struct{})
+			doneCh := make(chan struct{})
+
+			numMessages := tc.numMessages
+			messageCh := make(chan struct{}, numMessages)
+
+			mockDispatcher := &MockDispatcher{
+				onDispatch: func(ctx context.Context, method string, data []byte) error {
+					// Emulate delay due to some work.
+					time.Sleep(20 * time.Millisecond)
+					messageCh <- struct{}{}
+					return nil
+				},
+			}
+			config := KafkaConfig{
+				Brokers:             []string{testKafkaBrokerURL},
+				Topics:              []string{testKafkaTopic},
+				ConsumerGroup:       uuid.New().String(),
+				PartitionBufferSize: tc.partitionBuffer,
+			}
+			consumer, err := NewKafkaConsumer("test", uuid.NewString(), &MockLogger{}, mockDispatcher, config)
+			require.NoError(t, err)
+
+			var records []*kgo.Record
+			for i := 0; i < numMessages; i++ {
+				records = append(records, &kgo.Record{Topic: testKafkaTopic, Value: []byte(`{"hello": "` + strconv.Itoa(i) + `"}`)})
+				if (i+1)%100 == 0 {
+					err = produceManyRecords(records...)
+					if err != nil {
+						t.Fatal(err)
+					}
+					records = nil
+					t.Logf("produced %d messages", i+1)
+				}
+			}
+
+			t.Logf("all messages produced, 3, 2, 1, go!")
+			time.Sleep(time.Second)
+
+			go func() {
+				err := consumer.Run(ctx)
+				require.ErrorIs(t, err, context.Canceled)
+				close(consumerClosed)
+			}()
+
+			var numProcessed int64
+			go func() {
+				for {
+					select {
+					case <-ctx.Done():
+						return
+					case <-time.After(time.Second):
+						t.Logf("processed %d messages", atomic.LoadInt64(&numProcessed))
+					}
+				}
+			}()
+
+			for i := 0; i < numMessages; i++ {
+				<-messageCh
+				atomic.AddInt64(&numProcessed, 1)
+			}
+			t.Logf("all messages processed")
+			cancel()
+			waitCh(t, consumerClosed, 30*time.Second, "timeout waiting for consumer closed")
+			close(doneCh)
+		})
+	}
+}
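
Note for reviewers: the essence of the fix is the pause-and-rewind step in pollUntilFatal. When a partition's recs buffer cannot accept the fetched batch, the client pauses fetching for that partition and rewinds its next fetch offset back to the first record of the batch it could not hand off, so the same batch is fetched again once the partition is resumed. The standalone sketch below illustrates that pattern with franz-go outside of the Centrifugo consumer; the broker address, topic, group and the busy() predicate are hypothetical placeholders, not values from this patch.

// A minimal, self-contained sketch of the pause-and-rewind pattern, assuming a
// local broker and example topic/group names. It is NOT the Centrifugo consumer.
package main

import (
	"context"
	"log"

	"github.com/twmb/franz-go/pkg/kgo"
)

func main() {
	client, err := kgo.NewClient(
		kgo.SeedBrokers("localhost:9092"),  // assumption: local broker
		kgo.ConsumeTopics("example-topic"), // assumption: example topic
		kgo.ConsumerGroup("example-group"), // assumption: example group
		kgo.BlockRebalanceOnPoll(),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// busy stands in for "the per-partition processing buffer has no free space".
	busy := func(p kgo.FetchTopicPartition) bool { return len(p.Records) > 100 }

	for {
		fetches := client.PollRecords(context.Background(), 1000)
		if fetches.IsClientClosed() {
			return
		}
		fetches.EachPartition(func(p kgo.FetchTopicPartition) {
			if len(p.Records) == 0 {
				return
			}
			if busy(p) {
				// Stop fetching this partition until the processing side catches up.
				client.PauseFetchPartitions(map[string][]int32{p.Topic: {p.Partition}})
				// Rewind the next fetch to the first record of this batch. Without this,
				// the records fetched in this poll would be skipped after resume.
				client.SetOffsets(map[string]map[int32]kgo.EpochOffset{
					p.Topic: {p.Partition: {Epoch: -1, Offset: p.Records[0].Offset}},
				})
				return
			}
			// Process p.Records here; once there is space again, call
			// client.ResumeFetchPartitions(map[string][]int32{p.Topic: {p.Partition}}).
		})
		client.AllowRebalance()
	}
}

In the patch itself the same two calls (PauseFetchPartitions followed by SetOffsets) are made from pollUntilFatal when the per-partition recs channel is full, and resuming is handled by the partition consumer after it drains its buffer.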