kafka: fix possibility to lose records under load
FZambia committed Dec 3, 2024
1 parent 0bde1f4 commit 8bcb850
Showing 2 changed files with 234 additions and 31 deletions.
78 changes: 51 additions & 27 deletions internal/consuming/kafka.go
@@ -42,6 +42,10 @@ type KafkaConfig struct {
// will pause fetching records from Kafka. By default, this is 16.
// Set to -1 to use non-buffered channel.
PartitionBufferSize int `mapstructure:"partition_buffer_size" json:"partition_buffer_size"`

// FetchMaxBytes is the maximum number of bytes to fetch from Kafka in a single request.
// If not set, the default of 50MB is used.
FetchMaxBytes int32 `mapstructure:"fetch_max_bytes" json:"fetch_max_bytes"`
}

type topicPartition struct {
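The new FetchMaxBytes field in the KafkaConfig hunk above is wired into the client in the next hunk via kgo.FetchMaxBytes. As a minimal sketch of how the option could be set (the broker, topic, and group values are placeholders, and the snippet assumes it runs inside the internal/consuming package, for example in a test like the ones further below):

config := KafkaConfig{
	Brokers:             []string{"localhost:9092"}, // placeholder broker
	Topics:              []string{"example-topic"},  // placeholder topic
	ConsumerGroup:       "example-group",            // placeholder group
	PartitionBufferSize: 16,               // pause a partition when its buffer fills up
	FetchMaxBytes:       10 * 1024 * 1024, // cap a single fetch at 10MB instead of the 50MB default
}

Lowering FetchMaxBytes bounds how much data a single poll can bring into memory, which pairs naturally with the partition pausing introduced below.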
@@ -142,6 +146,9 @@ func (c *KafkaConsumer) initClient() (*kgo.Client, error) {
kgo.ClientID(kafkaClientID),
kgo.InstanceID(c.getInstanceID()),
}
if c.config.FetchMaxBytes > 0 {
opts = append(opts, kgo.FetchMaxBytes(c.config.FetchMaxBytes))
}
if c.config.TLS {
tlsOptionsMap, err := c.config.TLSOptions.ToMap()
if err != nil {
@@ -284,12 +291,19 @@ func (c *KafkaConsumer) pollUntilFatal(ctx context.Context) error {
return fmt.Errorf("poll error: %w", errors.Join(errs...))
}

pausedTopicPartitions := map[topicPartition]struct{}{}
fetches.EachPartition(func(p kgo.FetchTopicPartition) {
if len(p.Records) == 0 {
return
}

tp := topicPartition{p.Topic, p.Partition}
if _, paused := pausedTopicPartitions[tp]; paused {
// We have already paused this partition during this poll, so we should not
// process more records from it now. Partition processing will resume from the
// correct offset once there is space in the recs buffer.
return
}

// Since we are using BlockRebalanceOnPoll, we can be
// sure this partition consumer exists:
@@ -310,6 +324,12 @@ func (c *KafkaConsumer) pollUntilFatal(ctx context.Context) error {
// keeping records in memory and blocking rebalance. Resume will be called after
// records are processed by c.consumers[tp].
c.client.PauseFetchPartitions(partitionsToPause)
pausedTopicPartitions[tp] = struct{}{}
// To make the next poll start from the correct offset, we need to set it manually
// to the offset of the first record in the batch. Otherwise, the next poll would
// return the following batch and the current one would be lost.
epochOffset := kgo.EpochOffset{Epoch: -1, Offset: p.Records[0].Offset}
c.client.SetOffsets(map[string]map[int32]kgo.EpochOffset{p.Topic: {p.Partition: epochOffset}})
}
})
c.client.AllowRebalance()
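The two hunks above are the core of the fix. When a partition's recs buffer has no room, the partition is paused, remembered in pausedTopicPartitions so later batches from the same poll are skipped, and the consume position is rewound to the first record of the batch that could not be delivered. A condensed sketch of that pattern, assuming franz-go's *kgo.Client; the helper name and standalone file are illustrative, not code from kafka.go:

package consuming

import "github.com/twmb/franz-go/pkg/kgo"

// deliverOrPause hands a fetched batch to the partition consumer's buffered
// channel; if the buffer is full it pauses the partition and rewinds the offset
// so the same batch is fetched again after the partition is resumed.
func deliverOrPause(cl *kgo.Client, recs chan kgo.FetchTopicPartition, p kgo.FetchTopicPartition) {
	select {
	case recs <- p:
		// Buffer had room: the partition consumer goroutine will process the batch.
	default:
		// Buffer is full: stop fetching this partition so records do not pile up
		// in memory while rebalance is blocked by the poll loop.
		cl.PauseFetchPartitions(map[string][]int32{p.Topic: {p.Partition}})
		// Rewind the consume position to the first record of this batch so these
		// records are returned again after ResumeFetchPartitions. Epoch -1 means
		// the leader epoch is unknown and epoch validation is skipped.
		cl.SetOffsets(map[string]map[int32]kgo.EpochOffset{
			p.Topic: {p.Partition: {Epoch: -1, Offset: p.Records[0].Offset}},
		})
	}
}

Without the SetOffsets call the client keeps advancing its internal position past records that were never handed to a consumer, so the next poll after resuming would start at the following batch; that is exactly the record loss the commit title refers to.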
@@ -355,15 +375,25 @@ func (c *KafkaConsumer) assigned(ctx context.Context, cl *kgo.Client, assigned m
}
for topic, partitions := range assigned {
for _, partition := range partitions {
quitCh := make(chan struct{})
partitionCtx, cancel := context.WithCancel(ctx)
go func() {
select {
case <-ctx.Done():
cancel()
case <-quitCh:
cancel()
}
}()
pc := &partitionConsumer{
clientCtx: ctx,
dispatcher: c.dispatcher,
logger: c.logger,
cl: cl,
topic: topic,
partition: partition,

quit: make(chan struct{}),
partitionCtx: partitionCtx,
dispatcher: c.dispatcher,
logger: c.logger,
cl: cl,
topic: topic,
partition: partition,

quit: quitCh,
done: make(chan struct{}),
recs: make(chan kgo.FetchTopicPartition, bufferSize),
}
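The assigned() changes above replace the old clientCtx plus quit pair with a per-partition context: a small goroutine cancels it when either the client context ends or the partition's quit channel is closed on revoke. A minimal sketch of that wiring (the helper name is illustrative):

package consuming

import "context"

// newPartitionContext derives a context that ends when either the parent
// (client) context is done or the partition's quit channel is closed.
func newPartitionContext(parent context.Context, quit <-chan struct{}) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(parent)
	go func() {
		select {
		case <-parent.Done():
			cancel() // parent cancellation already propagates; this just ends the goroutine
		case <-quit:
			cancel() // partition revoked: stop work for this partition only
		}
	}()
	return ctx, cancel
}

With a single partitionCtx, processRecords and consume below only need one Done() case per select instead of watching both clientCtx and quit.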
@@ -407,12 +437,12 @@ func (c *KafkaConsumer) killConsumers(lost map[string][]int32) {
}

type partitionConsumer struct {
clientCtx context.Context
dispatcher Dispatcher
logger Logger
cl *kgo.Client
topic string
partition int32
partitionCtx context.Context
dispatcher Dispatcher
logger Logger
cl *kgo.Client
topic string
partition int32

quit chan struct{}
done chan struct{}
@@ -422,40 +452,36 @@ type partitionConsumer struct {
func (pc *partitionConsumer) processRecords(records []*kgo.Record) {
for _, record := range records {
select {
case <-pc.clientCtx.Done():
return
case <-pc.quit:
case <-pc.partitionCtx.Done():
return
default:
}

var e KafkaJSONEvent
err := json.Unmarshal(record.Value, &e)
if err != nil {
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error unmarshalling event from Kafka", map[string]any{"error": err.Error(), "topic": record.Topic, "partition": record.Partition}))
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error unmarshalling record value from Kafka", map[string]any{"error": err.Error(), "topic": record.Topic, "partition": record.Partition}))
pc.cl.MarkCommitRecords(record)
continue
}

var backoffDuration time.Duration = 0
retries := 0
for {
err := pc.dispatcher.Dispatch(pc.clientCtx, e.Method, e.Payload)
err := pc.dispatcher.Dispatch(pc.partitionCtx, e.Method, e.Payload)
if err == nil {
if retries > 0 {
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelInfo, "OK processing events after errors", map[string]any{}))
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelInfo, "OK processing record after errors", map[string]any{}))
}
pc.cl.MarkCommitRecords(record)
break
}
retries++
backoffDuration = getNextBackoffDuration(backoffDuration, retries)
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error processing consumed event", map[string]any{"error": err.Error(), "method": e.Method, "nextAttemptIn": backoffDuration.String()}))
pc.logger.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error processing consumed record", map[string]any{"error": err.Error(), "method": e.Method, "next_attempt_in": backoffDuration.String()}))
select {
case <-time.After(backoffDuration):
case <-pc.quit:
return
case <-pc.clientCtx.Done():
case <-pc.partitionCtx.Done():
return
}
}
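getNextBackoffDuration is defined elsewhere in the package and is not part of this diff; the retry loop above only relies on it growing the delay between failed dispatch attempts. A plausible capped exponential backoff of that shape, with illustrative constants rather than the project's actual values:

package consuming

import "time"

// getNextBackoffDurationSketch doubles the delay on every retry and caps it.
// The first parameter mirrors the real call site but is unused in this sketch.
func getNextBackoffDurationSketch(_ time.Duration, retries int) time.Duration {
	const (
		initial    = 50 * time.Millisecond
		maxBackoff = 5 * time.Second
	)
	d := initial << uint(retries-1) // 50ms, 100ms, 200ms, ...
	if d <= 0 || d > maxBackoff {   // cap growth and guard against shift overflow
		d = maxBackoff
	}
	return d
}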
@@ -471,9 +497,7 @@ func (pc *partitionConsumer) consume() {
defer resumeConsuming()
for {
select {
case <-pc.clientCtx.Done():
return
case <-pc.quit:
case <-pc.partitionCtx.Done():
return
case p := <-pc.recs:
pc.processRecords(p.Records)
187 changes: 183 additions & 4 deletions internal/consuming/kafka_test.go
@@ -7,7 +7,9 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

@@ -44,15 +46,26 @@ func (m *MockLogger) Log(_ centrifuge.LogEntry) {
// Implement mock logic, e.g., storing log entries for assertions
}

func produceManyRecords(records ...*kgo.Record) error {
client, err := kgo.NewClient(kgo.SeedBrokers(testKafkaBrokerURL))
if err != nil {
return fmt.Errorf("failed to create Kafka client: %w", err)
}
defer client.Close()
err = client.ProduceSync(context.Background(), records...).FirstErr()
if err != nil {
return fmt.Errorf("failed to produce message: %w", err)
}
return nil
}

func produceTestMessage(topic string, message []byte) error {
// Create a new client
client, err := kgo.NewClient(kgo.SeedBrokers(testKafkaBrokerURL))
if err != nil {
return fmt.Errorf("failed to create Kafka client: %w", err)
}
defer client.Close()

// Produce a message
err = client.ProduceSync(context.Background(), &kgo.Record{Topic: topic, Partition: 0, Value: message}).FirstErr()
if err != nil {
return fmt.Errorf("failed to produce message: %w", err)
@@ -61,7 +74,6 @@ func produceTestMessage(topic string, message []byte) error {
}

func produceTestMessageToPartition(topic string, message []byte, partition int32) error {
// Create a new client.
client, err := kgo.NewClient(
kgo.SeedBrokers(testKafkaBrokerURL),
kgo.RecordPartitioner(kgo.ManualPartitioner()),
@@ -71,7 +83,6 @@ func produceTestMessageToPartition(topic string, message []byte, partition int32
}
defer client.Close()

// Produce a message until we hit desired partition.
res := client.ProduceSync(context.Background(), &kgo.Record{
Topic: topic, Partition: partition, Value: message})
if res.FirstErr() != nil {
@@ -427,3 +438,171 @@ func TestKafkaConsumer_BlockedPartitionDoesNotBlockAnotherPartition(t *testing.T
})
}
}

func TestKafkaConsumer_PausePartitions(t *testing.T) {
t.Parallel()
testKafkaTopic := "consumer_test_" + uuid.New().String()
testPayload1 := []byte(`{"key":"value1"}`)
testPayload2 := []byte(`{"key":"value2"}`)
testPayload3 := []byte(`{"key":"value3"}`)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err := createTestTopic(ctx, testKafkaTopic, 1, 1)
require.NoError(t, err)

event1Received := make(chan struct{})
event2Received := make(chan struct{})
event3Received := make(chan struct{})
consumerClosed := make(chan struct{})
doneCh := make(chan struct{})

config := KafkaConfig{
Brokers: []string{testKafkaBrokerURL},
Topics: []string{testKafkaTopic},
ConsumerGroup: uuid.New().String(),
PartitionBufferSize: -1,
}

numCalls := 0

mockDispatcher := &MockDispatcher{
onDispatch: func(ctx context.Context, method string, data []byte) error {
numCalls++
if numCalls == 1 {
close(event1Received)
time.Sleep(5 * time.Second)
return nil
} else if numCalls == 2 {
close(event2Received)
return nil
}
close(event3Received)
return nil
},
}
consumer, err := NewKafkaConsumer("test", uuid.NewString(), &MockLogger{}, mockDispatcher, config)
require.NoError(t, err)

go func() {
err = produceTestMessage(testKafkaTopic, testPayload1)
require.NoError(t, err)
<-event1Received
// At this point message 1 is being processed and the next produced message will
// cause a partition pause.
err = produceTestMessage(testKafkaTopic, testPayload2)
require.NoError(t, err)
<-event2Received
err = produceTestMessage(testKafkaTopic, testPayload3)
require.NoError(t, err)
}()

go func() {
err := consumer.Run(ctx)
require.ErrorIs(t, err, context.Canceled)
close(consumerClosed)
}()

waitCh(t, event1Received, 30*time.Second, "timeout waiting for event 1")
waitCh(t, event2Received, 30*time.Second, "timeout waiting for event 2")
waitCh(t, event3Received, 30*time.Second, "timeout waiting for event 3")
cancel()
waitCh(t, consumerClosed, 30*time.Second, "timeout waiting for consumer closed")
close(doneCh)
}

func TestKafkaConsumer_WorksCorrectlyInLoadedTopic(t *testing.T) {
t.Skip()
t.Parallel()

testCases := []struct {
numPartitions int32
numMessages int
partitionBuffer int
}{
//{numPartitions: 1, numMessages: 1000, partitionBuffer: -1},
//{numPartitions: 1, numMessages: 1000, partitionBuffer: 1},
//{numPartitions: 10, numMessages: 10000, partitionBuffer: -1},
{numPartitions: 10, numMessages: 10000, partitionBuffer: 1},
}

for _, tc := range testCases {
name := fmt.Sprintf("partitions=%d,messages=%d,buffer=%d", tc.numPartitions, tc.numMessages, tc.partitionBuffer)
t.Run(name, func(t *testing.T) {
testKafkaTopic := "consumer_test_" + uuid.New().String()

ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
defer cancel()

err := createTestTopic(ctx, testKafkaTopic, tc.numPartitions, 1)
require.NoError(t, err)

consumerClosed := make(chan struct{})
doneCh := make(chan struct{})

numMessages := tc.numMessages
messageCh := make(chan struct{}, numMessages)

mockDispatcher := &MockDispatcher{
onDispatch: func(ctx context.Context, method string, data []byte) error {
// Emulate delay due to some work.
time.Sleep(20 * time.Millisecond)
messageCh <- struct{}{}
return nil
},
}
config := KafkaConfig{
Brokers: []string{testKafkaBrokerURL},
Topics: []string{testKafkaTopic},
ConsumerGroup: uuid.New().String(),
PartitionBufferSize: tc.partitionBuffer,
}
consumer, err := NewKafkaConsumer("test", uuid.NewString(), &MockLogger{}, mockDispatcher, config)
require.NoError(t, err)

var records []*kgo.Record
for i := 0; i < numMessages; i++ {
records = append(records, &kgo.Record{Topic: testKafkaTopic, Value: []byte(`{"hello": "` + strconv.Itoa(i) + `"}`)})
if (i+1)%100 == 0 {
err = produceManyRecords(records...)
if err != nil {
t.Fatal(err)
}
records = nil
t.Logf("produced %d messages", i+1)
}
}

t.Logf("all messages produced, 3, 2, 1, go!")
time.Sleep(time.Second)

go func() {
err := consumer.Run(ctx)
require.ErrorIs(t, err, context.Canceled)
close(consumerClosed)
}()

var numProcessed int64
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Second):
t.Logf("processed %d messages", atomic.LoadInt64(&numProcessed))
}
}
}()

for i := 0; i < numMessages; i++ {
<-messageCh
atomic.AddInt64(&numProcessed, 1)
}
t.Logf("all messages processed")
cancel()
waitCh(t, consumerClosed, 30*time.Second, "timeout waiting for consumer closed")
close(doneCh)
})
}
}
