Kafka: Fix potential loss of records under load #917

Merged · 1 commit · Dec 5, 2024
78 changes: 51 additions & 27 deletions internal/consuming/kafka.go
@@ -42,6 +42,10 @@ type KafkaConfig struct {
// will pause fetching records from Kafka. By default, this is 16.
// Set to -1 to use non-buffered channel.
PartitionBufferSize int `mapstructure:"partition_buffer_size" json:"partition_buffer_size"`

// FetchMaxBytes is the maximum number of bytes to fetch from Kafka in a single request.
// If not set, the default of 50MB is used.
FetchMaxBytes int32 `mapstructure:"fetch_max_bytes" json:"fetch_max_bytes"`
}

type topicPartition struct {
@@ -142,6 +146,9 @@ func (c *KafkaConsumer) initClient() (*kgo.Client, error) {
kgo.ClientID(kafkaClientID),
kgo.InstanceID(c.getInstanceID()),
}
if c.config.FetchMaxBytes > 0 {
opts = append(opts, kgo.FetchMaxBytes(c.config.FetchMaxBytes))
}
if c.config.TLS {
tlsOptionsMap, err := c.config.TLSOptions.ToMap()
if err != nil {
@@ -284,12 +291,19 @@ func (c *KafkaConsumer) pollUntilFatal(ctx context.Context) error {
return fmt.Errorf("poll error: %w", errors.Join(errs...))
}

pausedTopicPartitions := map[topicPartition]struct{}{}
fetches.EachPartition(func(p kgo.FetchTopicPartition) {
if len(p.Records) == 0 {
return
}

tp := topicPartition{p.Topic, p.Partition}
if _, paused := pausedTopicPartitions[tp]; paused {
// We have already paused this partition during this poll, so we should not
// process any more records from it. We will resume partition processing from the
// correct offset soon, once there is space in the recs buffer.
return
}

// Since we are using BlockRebalanceOnPoll, we can be
// sure this partition consumer exists:
@@ -310,6 +324,12 @@ func (c *KafkaConsumer) pollUntilFatal(ctx context.Context) error {
// keeping records in memory and blocking rebalance. Resume will be called after
// records are processed by c.consumers[tp].
c.client.PauseFetchPartitions(partitionsToPause)
pausedTopicPartitions[tp] = struct{}{}
// To make the next poll start from the correct offset, we need to set it manually
// to the offset of the first record in the batch. Otherwise, the next poll would
// return the next record batch and we would lose the current one.
epochOffset := kgo.EpochOffset{Epoch: -1, Offset: p.Records[0].Offset}
c.client.SetOffsets(map[string]map[int32]kgo.EpochOffset{p.Topic: {p.Partition: epochOffset}})
}
})
c.client.AllowRebalance()
@@ -355,15 +375,25 @@ func (c *KafkaConsumer) assigned(ctx context.Context, cl *kgo.Client, assigned m
}
for topic, partitions := range assigned {
for _, partition := range partitions {
quitCh := make(chan struct{})
partitionCtx, cancel := context.WithCancel(ctx)
go func() {
select {
case <-ctx.Done():
cancel()
case <-quitCh:
cancel()
}
}()
pc := &partitionConsumer{
clientCtx: ctx,
dispatcher: c.dispatcher,
logger: c.logger,
cl: cl,
topic: topic,
partition: partition,

quit: make(chan struct{}),
partitionCtx: partitionCtx,
dispatcher: c.dispatcher,
logger: c.logger,
cl: cl,
topic: topic,
partition: partition,

quit: quitCh,
done: make(chan struct{}),
recs: make(chan kgo.FetchTopicPartition, bufferSize),
}
@@ -407,12 +437,12 @@ func (c *KafkaConsumer) killConsumers(lost map[string][]int32) {
}

type partitionConsumer struct {
clientCtx context.Context
dispatcher Dispatcher
logger Logger
cl *kgo.Client
topic string
partition int32
partitionCtx context.Context
dispatcher Dispatcher
logger Logger
cl *kgo.Client
topic string
partition int32

quit chan struct{}
done chan struct{}
@@ -422,40 +452,36 @@ type partitionConsumer struct {
func (pc *partitionConsumer) processRecords(records []*kgo.Record) {
for _, record := range records {
select {
case <-pc.clientCtx.Done():
return
case <-pc.quit:
case <-pc.partitionCtx.Done():
return
default:
}

var e KafkaJSONEvent
err := json.Unmarshal(record.Value, &e)
if err != nil {
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error unmarshalling event from Kafka", map[string]any{"error": err.Error(), "topic": record.Topic, "partition": record.Partition}))
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error unmarshalling record value from Kafka", map[string]any{"error": err.Error(), "topic": record.Topic, "partition": record.Partition}))
pc.cl.MarkCommitRecords(record)
continue
}

var backoffDuration time.Duration = 0
retries := 0
for {
err := pc.dispatcher.Dispatch(pc.clientCtx, e.Method, e.Payload)
err := pc.dispatcher.Dispatch(pc.partitionCtx, e.Method, e.Payload)
if err == nil {
if retries > 0 {
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelInfo, "OK processing events after errors", map[string]any{}))
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelInfo, "OK processing record after errors", map[string]any{}))
}
pc.cl.MarkCommitRecords(record)
break
}
retries++
backoffDuration = getNextBackoffDuration(backoffDuration, retries)
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error processing consumed event", map[string]any{"error": err.Error(), "method": e.Method, "nextAttemptIn": backoffDuration.String()}))
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error processing consumed record", map[string]any{"error": err.Error(), "method": e.Method, "next_attempt_in": backoffDuration.String()}))
select {
case <-time.After(backoffDuration):
case <-pc.quit:
return
case <-pc.clientCtx.Done():
case <-pc.partitionCtx.Done():
return
}
}
@@ -471,9 +497,7 @@ func (pc *partitionConsumer) consume() {
defer resumeConsuming()
for {
select {
case <-pc.clientCtx.Done():
return
case <-pc.quit:
case <-pc.partitionCtx.Done():
return
case p := <-pc.recs:
pc.processRecords(p.Records)
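
The core of this change is the pause-and-rewind pattern in pollUntilFatal above: when a partition consumer's buffer is full, the client pauses fetching for that partition and rewinds its consume offset to the first record of the batch it could not hand off, so nothing is dropped while the partition is paused. Below is a minimal standalone sketch of that pattern using franz-go's kgo package; the handleBatch helper, package name, and poll size are illustrative assumptions, not code from this PR.

```go
// Minimal sketch (not PR code) of the pause-and-rewind pattern with franz-go.
package example

import (
	"context"

	"github.com/twmb/franz-go/pkg/kgo"
)

// handleBatch stands in for handing records to a per-partition worker; it returns
// false when the worker's buffer is full and the batch could not be accepted.
func handleBatch(p kgo.FetchTopicPartition) bool { return true } // illustrative

func pollLoop(ctx context.Context, client *kgo.Client) {
	for {
		fetches := client.PollRecords(ctx, 100)
		if fetches.IsClientClosed() || ctx.Err() != nil {
			return
		}
		paused := map[string]map[int32]struct{}{}
		fetches.EachPartition(func(p kgo.FetchTopicPartition) {
			if len(p.Records) == 0 {
				return
			}
			if _, ok := paused[p.Topic][p.Partition]; ok {
				// Already paused during this poll; skip further batches for it.
				return
			}
			if !handleBatch(p) {
				// Buffer is full: stop fetching this partition until the worker catches up.
				client.PauseFetchPartitions(map[string][]int32{p.Topic: {p.Partition}})
				if paused[p.Topic] == nil {
					paused[p.Topic] = map[int32]struct{}{}
				}
				paused[p.Topic][p.Partition] = struct{}{}
				// Rewind the consume offset to the first unprocessed record, so that after
				// ResumeFetchPartitions the same batch is fetched again instead of being lost.
				client.SetOffsets(map[string]map[int32]kgo.EpochOffset{
					p.Topic: {p.Partition: {Epoch: -1, Offset: p.Records[0].Offset}},
				})
			}
		})
		client.AllowRebalance()
	}
}
```

The important detail is the explicit SetOffsets call: PauseFetchPartitions alone only stops future fetches, but records already fetched and not handed off would otherwise be skipped once the partition is resumed.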
187 changes: 183 additions & 4 deletions internal/consuming/kafka_test.go
@@ -7,7 +7,9 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

@@ -44,15 +46,26 @@ func (m *MockLogger) Log(_ centrifuge.LogEntry) {
// Implement mock logic, e.g., storing log entries for assertions
}

func produceManyRecords(records ...*kgo.Record) error {
client, err := kgo.NewClient(kgo.SeedBrokers(testKafkaBrokerURL))
if err != nil {
return fmt.Errorf("failed to create Kafka client: %w", err)
}
defer client.Close()
err = client.ProduceSync(context.Background(), records...).FirstErr()
if err != nil {
return fmt.Errorf("failed to produce message: %w", err)
}
return nil
}

func produceTestMessage(topic string, message []byte) error {
// Create a new client
client, err := kgo.NewClient(kgo.SeedBrokers(testKafkaBrokerURL))
if err != nil {
return fmt.Errorf("failed to create Kafka client: %w", err)
}
defer client.Close()

// Produce a message
err = client.ProduceSync(context.Background(), &kgo.Record{Topic: topic, Partition: 0, Value: message}).FirstErr()
if err != nil {
return fmt.Errorf("failed to produce message: %w", err)
@@ -61,7 +74,6 @@ func produceTestMessage(topic string, message []byte) error {
}

func produceTestMessageToPartition(topic string, message []byte, partition int32) error {
// Create a new client.
client, err := kgo.NewClient(
kgo.SeedBrokers(testKafkaBrokerURL),
kgo.RecordPartitioner(kgo.ManualPartitioner()),
@@ -71,7 +83,6 @@ func produceTestMessageToPartition(topic string, message []byte, partition int32
}
defer client.Close()

// Produce a message until we hit desired partition.
res := client.ProduceSync(context.Background(), &kgo.Record{
Topic: topic, Partition: partition, Value: message})
if res.FirstErr() != nil {
@@ -427,3 +438,171 @@ func TestKafkaConsumer_BlockedPartitionDoesNotBlockAnotherPartition(t *testing.T
})
}
}

func TestKafkaConsumer_PausePartitions(t *testing.T) {
t.Parallel()
testKafkaTopic := "consumer_test_" + uuid.New().String()
testPayload1 := []byte(`{"key":"value1"}`)
testPayload2 := []byte(`{"key":"value2"}`)
testPayload3 := []byte(`{"key":"value3"}`)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err := createTestTopic(ctx, testKafkaTopic, 1, 1)
require.NoError(t, err)

event1Received := make(chan struct{})
event2Received := make(chan struct{})
event3Received := make(chan struct{})
consumerClosed := make(chan struct{})
doneCh := make(chan struct{})

config := KafkaConfig{
Brokers: []string{testKafkaBrokerURL},
Topics: []string{testKafkaTopic},
ConsumerGroup: uuid.New().String(),
PartitionBufferSize: -1,
}

numCalls := 0

mockDispatcher := &MockDispatcher{
onDispatch: func(ctx context.Context, method string, data []byte) error {
numCalls++
if numCalls == 1 {
close(event1Received)
time.Sleep(5 * time.Second)
return nil
} else if numCalls == 2 {
close(event2Received)
return nil
}
close(event3Received)
return nil
},
}
consumer, err := NewKafkaConsumer("test", uuid.NewString(), &MockLogger{}, mockDispatcher, config)
require.NoError(t, err)

go func() {
err = produceTestMessage(testKafkaTopic, testPayload1)
require.NoError(t, err)
<-event1Received
// At this point message 1 is being processed and the next produced message will
// cause a partition pause.
err = produceTestMessage(testKafkaTopic, testPayload2)
require.NoError(t, err)
<-event2Received
err = produceTestMessage(testKafkaTopic, testPayload3)
require.NoError(t, err)
}()

go func() {
err := consumer.Run(ctx)
require.ErrorIs(t, err, context.Canceled)
close(consumerClosed)
}()

waitCh(t, event1Received, 30*time.Second, "timeout waiting for event 1")
waitCh(t, event2Received, 30*time.Second, "timeout waiting for event 2")
waitCh(t, event3Received, 30*time.Second, "timeout waiting for event 3")
cancel()
waitCh(t, consumerClosed, 30*time.Second, "timeout waiting for consumer closed")
close(doneCh)
}

func TestKafkaConsumer_WorksCorrectlyInLoadedTopic(t *testing.T) {
t.Skip()
t.Parallel()

testCases := []struct {
numPartitions int32
numMessages int
partitionBuffer int
}{
//{numPartitions: 1, numMessages: 1000, partitionBuffer: -1},
//{numPartitions: 1, numMessages: 1000, partitionBuffer: 1},
//{numPartitions: 10, numMessages: 10000, partitionBuffer: -1},
{numPartitions: 10, numMessages: 10000, partitionBuffer: 1},
}

for _, tc := range testCases {
name := fmt.Sprintf("partitions=%d,messages=%d,buffer=%d", tc.numPartitions, tc.numMessages, tc.partitionBuffer)
t.Run(name, func(t *testing.T) {
testKafkaTopic := "consumer_test_" + uuid.New().String()

ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
defer cancel()

err := createTestTopic(ctx, testKafkaTopic, tc.numPartitions, 1)
require.NoError(t, err)

consumerClosed := make(chan struct{})
doneCh := make(chan struct{})

numMessages := tc.numMessages
messageCh := make(chan struct{}, numMessages)

mockDispatcher := &MockDispatcher{
onDispatch: func(ctx context.Context, method string, data []byte) error {
// Emulate delay due to some work.
time.Sleep(20 * time.Millisecond)
messageCh <- struct{}{}
return nil
},
}
config := KafkaConfig{
Brokers: []string{testKafkaBrokerURL},
Topics: []string{testKafkaTopic},
ConsumerGroup: uuid.New().String(),
PartitionBufferSize: tc.partitionBuffer,
}
consumer, err := NewKafkaConsumer("test", uuid.NewString(), &MockLogger{}, mockDispatcher, config)
require.NoError(t, err)

var records []*kgo.Record
for i := 0; i < numMessages; i++ {
records = append(records, &kgo.Record{Topic: testKafkaTopic, Value: []byte(`{"hello": "` + strconv.Itoa(i) + `"}`)})
if (i+1)%100 == 0 {
err = produceManyRecords(records...)
if err != nil {
t.Fatal(err)
}
records = nil
t.Logf("produced %d messages", i+1)
}
}

t.Logf("all messages produced, 3, 2, 1, go!")
time.Sleep(time.Second)

go func() {
err := consumer.Run(ctx)
require.ErrorIs(t, err, context.Canceled)
close(consumerClosed)
}()

var numProcessed int64
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Second):
t.Logf("processed %d messages", atomic.LoadInt64(&numProcessed))
}
}
}()

for i := 0; i < numMessages; i++ {
<-messageCh
atomic.AddInt64(&numProcessed, 1)
}
t.Logf("all messages processed")
cancel()
waitCh(t, consumerClosed, 30*time.Second, "timeout waiting for consumer closed")
close(doneCh)
})
}
}
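
As a usage note for the new FetchMaxBytes field added to KafkaConfig in this PR, here is a hedged configuration sketch. The broker address, topic, group, node name, and the 10MB value are placeholder assumptions, and logger, dispatcher, and ctx are whatever implementations the caller already has.

```go
// Sketch only: wiring the new fetch_max_bytes option into the consumer config.
cfg := KafkaConfig{
	Brokers:             []string{"localhost:9092"}, // placeholder broker
	Topics:              []string{"example-topic"},  // placeholder topic
	ConsumerGroup:       "example-group",            // placeholder group
	PartitionBufferSize: 16,                         // default buffer size
	FetchMaxBytes:       10 * 1024 * 1024,           // lower the 50MB default to 10MB
}
consumer, err := NewKafkaConsumer("test", "node-1", logger, dispatcher, cfg)
if err != nil {
	// handle configuration error
}
go func() { _ = consumer.Run(ctx) }()
```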