Bloom Gateway: Make tasks cancelable (#11792)
This PR refactors the bloom gateway workers so that tasks enqueued by a request no longer end up locking the results channel, and therefore the worker, when the request is cancelled (`context canceled`) or times out (`context deadline exceeded`).
It also returns errors from the shipper to the waiting request as soon as they occur, so the request can return early instead of waiting for all tasks to finish.

This PR also fixes worker shutdown: a worker now stops gracefully and works off the remaining tasks from the queue before exiting.
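
The core idea is to keep the producer of partial results from ever blocking on a consumer that has gone away. The following is a minimal, self-contained sketch of that drain-on-cancel pattern; the `result`, `produce`, and `consume` names are made up for illustration and are not the actual `Task` or worker types from this change:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// result stands in for v1.Output in this sketch.
type result struct{ id int }

// produce emulates a worker that keeps yielding partial results for a task.
// It must never block forever, even if the requester has already given up.
func produce(resCh chan<- result, done chan<- struct{}) {
	for i := 0; i < 5; i++ {
		time.Sleep(50 * time.Millisecond)
		resCh <- result{id: i}
	}
	close(resCh) // no more results
	close(done)  // signal that the task is finished
}

// consume mirrors the consumeTask idea: it keeps reading from resCh until it
// is closed, but once the request context is done it only drains and drops
// the remaining items, so the producer is never stuck on a send.
func consume(ctx context.Context, resCh <-chan result) []result {
	var accepted []result
	for res := range resCh {
		select {
		case <-ctx.Done():
			// request is gone: drop the partial result, keep draining
		default:
			accepted = append(accepted, res)
		}
	}
	return accepted
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Millisecond)
	defer cancel()

	resCh := make(chan result, 1)
	done := make(chan struct{})

	go produce(resCh, done)
	accepted := consume(ctx, resCh)

	<-done
	fmt.Printf("accepted %d partial results before cancellation\n", len(accepted))
}
```

In the actual change, the worker closes the task once it is finished, and `consumeTask` then forwards the task, together with its accumulated responses, to the request handler (see the diff in `pkg/bloomgateway/bloomgateway.go` below).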

---------

Signed-off-by: Christian Haudum <[email protected]>
chaudum authored Jan 29, 2024
1 parent c01a823 commit e0e143a
Showing 5 changed files with 352 additions and 159 deletions.
103 changes: 57 additions & 46 deletions pkg/bloomgateway/bloomgateway.go
@@ -299,6 +299,8 @@ func (g *Gateway) FilterChunkRefs(ctx context.Context, req *logproto.FilterChunk
return nil, err
}

logger := log.With(g.logger, "tenant", tenantID)

// start time == end time --> empty response
if req.From.Equal(req.Through) {
return &logproto.FilterChunkRefResponse{
@@ -327,79 +329,60 @@ func (g *Gateway) FilterChunkRefs(ctx context.Context, req *logproto.FilterChunk
return req.Refs[i].Fingerprint < req.Refs[j].Fingerprint
})

var expectedResponses int
seriesWithBloomsPerDay := partitionRequest(req)
var numSeries int
seriesByDay := partitionRequest(req)

// no tasks --> empty response
if len(seriesWithBloomsPerDay) == 0 {
if len(seriesByDay) == 0 {
return &logproto.FilterChunkRefResponse{
ChunkRefs: []*logproto.GroupedChunkRefs{},
}, nil
}

tasks := make([]Task, 0, len(seriesWithBloomsPerDay))
for _, seriesWithBounds := range seriesWithBloomsPerDay {
task, err := NewTask(tenantID, seriesWithBounds, req.Filters)
tasks := make([]Task, 0, len(seriesByDay))
for _, seriesWithBounds := range seriesByDay {
task, err := NewTask(ctx, tenantID, seriesWithBounds, req.Filters)
if err != nil {
return nil, err
}
tasks = append(tasks, task)
expectedResponses += len(seriesWithBounds.series)
numSeries += len(seriesWithBounds.series)
}

g.activeUsers.UpdateUserTimestamp(tenantID, time.Now())

errCh := make(chan error, 1)
resCh := make(chan v1.Output, 1)

// Ideally we could use an unbuffered channel here, but since we return the
// request on the first error, there can be cases where the request context
// is not done yet and the consumeTask() function wants to send to the
// tasksCh, but nobody reads from it any more.
tasksCh := make(chan Task, len(tasks))
for _, task := range tasks {
level.Info(g.logger).Log("msg", "enqueue task", "task", task.ID, "day", task.day, "series", len(task.series))
task := task
level.Info(logger).Log("msg", "enqueue task", "task", task.ID, "day", task.day, "series", len(task.series))
g.queue.Enqueue(tenantID, []string{}, task, func() {
// When enqueuing, we also add the task to the pending tasks
g.pendingTasks.Add(task.ID, task)
})

// Forward responses or error to the main channels
// TODO(chaudum): Refactor to make tasks cancelable
go func(t Task) {
for {
select {
case <-ctx.Done():
return
case err := <-t.ErrCh:
if ctx.Err() != nil {
level.Warn(g.logger).Log("msg", "received err from channel, but context is already done", "err", ctx.Err())
return
}
errCh <- err
case res := <-t.ResCh:
level.Debug(g.logger).Log("msg", "got partial result", "task", t.ID, "tenant", tenantID, "fp_int", uint64(res.Fp), "fp_hex", res.Fp, "chunks_to_remove", res.Removals.Len())
if ctx.Err() != nil {
level.Warn(g.logger).Log("msg", "received res from channel, but context is already done", "err", ctx.Err())
return
}
resCh <- res
}
}
}(task)
go consumeTask(ctx, task, tasksCh, logger)
}

responses := responsesPool.Get(expectedResponses)
responses := responsesPool.Get(numSeries)
defer responsesPool.Put(responses)
remaining := len(tasks)

outer:
for {
select {
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "waiting for results")
case err := <-errCh:
return nil, errors.Wrap(err, "waiting for results")
case res := <-resCh:
responses = append(responses, res)
// log line is helpful for debugging tests
level.Debug(g.logger).Log("msg", "got partial result", "progress", fmt.Sprintf("%d/%d", len(responses), expectedResponses))
// wait for all parts of the full response
if len(responses) == expectedResponses {
return nil, errors.Wrap(ctx.Err(), "request failed")
case task := <-tasksCh:
level.Info(logger).Log("msg", "task done", "task", task.ID, "err", task.Err())
if task.Err() != nil {
return nil, errors.Wrap(task.Err(), "request failed")
}
responses = append(responses, task.responses...)
remaining--
if remaining == 0 {
break outer
}
}
@@ -415,10 +398,38 @@ outer:
g.metrics.addUnfilteredCount(numChunksUnfiltered)
g.metrics.addFilteredCount(len(req.Refs))

level.Debug(g.logger).Log("msg", "return filtered chunk refs", "unfiltered", numChunksUnfiltered, "filtered", len(req.Refs))
level.Info(logger).Log("msg", "return filtered chunk refs", "unfiltered", numChunksUnfiltered, "filtered", len(req.Refs))
return &logproto.FilterChunkRefResponse{ChunkRefs: req.Refs}, nil
}

// consumeTask receives the v1.Output results yielded by the block querier on
// the task's result channel and stores them on the task.
// In case the task's context is done, it drains the remaining items until the
// task is closed by the worker.
// Once the task is closed, it sends the task with the results from the
// block querier to the supplied task channel.
func consumeTask(ctx context.Context, task Task, tasksCh chan<- Task, logger log.Logger) {
logger = log.With(logger, "task", task.ID)

for res := range task.resCh {
select {
case <-ctx.Done():
level.Debug(logger).Log("msg", "drop partial result", "fp_int", uint64(res.Fp), "fp_hex", res.Fp, "chunks_to_remove", res.Removals.Len())
default:
level.Debug(logger).Log("msg", "accept partial result", "fp_int", uint64(res.Fp), "fp_hex", res.Fp, "chunks_to_remove", res.Removals.Len())
task.responses = append(task.responses, res)
}
}

select {
case <-ctx.Done():
// do nothing
case <-task.Done():
// notify request handler about finished task
tasksCh <- task
}
}

func removeNotMatchingChunks(req *logproto.FilterChunkRefRequest, res v1.Output, logger log.Logger) {
// binary search index of fingerprint
idx := sort.Search(len(req.Refs), func(i int) bool {
107 changes: 106 additions & 1 deletion pkg/bloomgateway/bloomgateway_test.go
@@ -15,6 +15,7 @@ import (
"github.com/grafana/dskit/ring"
"github.com/grafana/dskit/services"
"github.com/grafana/dskit/user"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/model/labels"
@@ -183,6 +184,96 @@ func TestBloomGateway_FilterChunkRefs(t *testing.T) {
MaxOutstandingPerTenant: 1024,
}

t.Run("shipper error is propagated", func(t *testing.T) {
reg := prometheus.NewRegistry()
gw, err := New(cfg, schemaCfg, storageCfg, limits, ss, cm, logger, reg)
require.NoError(t, err)

now := mktime("2023-10-03 10:00")

bqs, data := createBlockQueriers(t, 10, now.Add(-24*time.Hour), now, 0, 1000)
mockStore := newMockBloomStore(bqs)
mockStore.err = errors.New("failed to fetch block")
gw.bloomShipper = mockStore

err = gw.initServices()
require.NoError(t, err)

err = services.StartAndAwaitRunning(context.Background(), gw)
require.NoError(t, err)
t.Cleanup(func() {
err = services.StopAndAwaitTerminated(context.Background(), gw)
require.NoError(t, err)
})

chunkRefs := createQueryInputFromBlockData(t, tenantID, data, 10)

// saturate workers
// then send additional request
for i := 0; i < gw.cfg.WorkerConcurrency+1; i++ {
req := &logproto.FilterChunkRefRequest{
From: now.Add(-24 * time.Hour),
Through: now,
Refs: groupRefs(t, chunkRefs),
Filters: []syntax.LineFilter{
{Ty: labels.MatchEqual, Match: "does not match"},
},
}

ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
ctx = user.InjectOrgID(ctx, tenantID)
t.Cleanup(cancelFn)

res, err := gw.FilterChunkRefs(ctx, req)
require.ErrorContainsf(t, err, "request failed: failed to fetch block", "%+v", res)
}
})

t.Run("request cancellation does not result in channel locking", func(t *testing.T) {
reg := prometheus.NewRegistry()
gw, err := New(cfg, schemaCfg, storageCfg, limits, ss, cm, logger, reg)
require.NoError(t, err)

now := mktime("2024-01-25 10:00")

bqs, data := createBlockQueriers(t, 50, now.Add(-24*time.Hour), now, 0, 1024)
mockStore := newMockBloomStore(bqs)
mockStore.delay = 50 * time.Millisecond // delay for each block - 50x50=2500ms
gw.bloomShipper = mockStore

err = gw.initServices()
require.NoError(t, err)

err = services.StartAndAwaitRunning(context.Background(), gw)
require.NoError(t, err)
t.Cleanup(func() {
err = services.StopAndAwaitTerminated(context.Background(), gw)
require.NoError(t, err)
})

chunkRefs := createQueryInputFromBlockData(t, tenantID, data, 100)

// saturate workers
// then send additional request
for i := 0; i < gw.cfg.WorkerConcurrency+1; i++ {
req := &logproto.FilterChunkRefRequest{
From: now.Add(-24 * time.Hour),
Through: now,
Refs: groupRefs(t, chunkRefs),
Filters: []syntax.LineFilter{
{Ty: labels.MatchEqual, Match: "does not match"},
},
}

ctx, cancelFn := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx = user.InjectOrgID(ctx, tenantID)
t.Cleanup(cancelFn)

res, err := gw.FilterChunkRefs(ctx, req)
require.ErrorContainsf(t, err, context.DeadlineExceeded.Error(), "%+v", res)
}
})

t.Run("returns unfiltered chunk refs if no filters provided", func(t *testing.T) {
reg := prometheus.NewRegistry()
gw, err := New(cfg, schemaCfg, storageCfg, limits, ss, cm, logger, reg)
@@ -428,12 +519,17 @@ func newMockBloomStore(bqs []bloomshipper.BlockQuerierWithFingerprintRange) *moc

type mockBloomStore struct {
bqs []bloomshipper.BlockQuerierWithFingerprintRange
// mock how long it takes to serve block queriers
delay time.Duration
// mock response error when serving block queriers in ForEach
err error
}

var _ bloomshipper.Interface = &mockBloomStore{}

// GetBlockRefs implements bloomshipper.Interface
func (s *mockBloomStore) GetBlockRefs(_ context.Context, tenant string, _ bloomshipper.Interval) ([]bloomshipper.BlockRef, error) {
time.Sleep(s.delay)
blocks := make([]bloomshipper.BlockRef, 0, len(s.bqs))
for i := range s.bqs {
blocks = append(blocks, bloomshipper.BlockRef{
@@ -452,6 +548,11 @@ func (s *mockBloomStore) Stop() {}

// Fetch implements bloomshipper.Interface
func (s *mockBloomStore) Fetch(_ context.Context, _ string, _ []bloomshipper.BlockRef, callback bloomshipper.ForEachBlockCallback) error {
if s.err != nil {
time.Sleep(s.delay)
return s.err
}

shuffled := make([]bloomshipper.BlockQuerierWithFingerprintRange, len(s.bqs))
_ = copy(shuffled, s.bqs)

@@ -461,7 +562,11 @@ func (s *mockBloomStore) Fetch(_ context.Context, _ string, _ []bloomshipper.Blo

for _, bq := range shuffled {
// ignore errors in the mock
_ = callback(bq.BlockQuerier, uint64(bq.MinFp), uint64(bq.MaxFp))
time.Sleep(s.delay)
err := callback(bq.BlockQuerier, uint64(bq.MinFp), uint64(bq.MaxFp))
if err != nil {
return err
}
}
return nil
}
