diff --git a/pkg/querier/ingester_querier.go b/pkg/querier/ingester_querier.go
index 830ef53961ad3..76c8b8bfc42db 100644
--- a/pkg/querier/ingester_querier.go
+++ b/pkg/querier/ingester_querier.go
@@ -5,6 +5,7 @@ import (
 	"net/http"
 	"slices"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/go-kit/log"
@@ -82,10 +83,91 @@ func newIngesterQuerier(querierConfig Config, clientCfg client.Config, ring ring
 	return &iq, nil
 }
 
+type ctxKeyType string
+
+const (
+	partitionCtxKey ctxKeyType = "partitionCtx"
+)
+
+type PartitionContext struct {
+	isPartitioned bool
+	ingestersUsed map[string]PartitionIngesterUsed
+	mtx           sync.Mutex
+}
+
+type PartitionIngesterUsed struct {
+	client logproto.QuerierClient
+	addr   string
+}
+
+func (p *PartitionContext) AddClient(client logproto.QuerierClient, addr string) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+	if !p.isPartitioned {
+		return
+	}
+	p.ingestersUsed[addr] = PartitionIngesterUsed{client: client, addr: addr}
+}
+
+func (p *PartitionContext) RemoveClient(addr string) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+	if !p.isPartitioned {
+		return
+	}
+	delete(p.ingestersUsed, addr)
+}
+
+func (p *PartitionContext) SetIsPartitioned(isPartitioned bool) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+	p.isPartitioned = isPartitioned
+}
+
+func (p *PartitionContext) IsPartitioned() bool {
+	return p.isPartitioned
+}
+
+func (p *PartitionContext) forQueriedIngesters(ctx context.Context, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	ingestersUsed := make([]PartitionIngesterUsed, 0, len(p.ingestersUsed))
+	for _, ingester := range p.ingestersUsed {
+		ingestersUsed = append(ingestersUsed, ingester)
+	}
+
+	return concurrency.ForEachJobMergeResults(ctx, ingestersUsed, 0, func(ctx context.Context, job PartitionIngesterUsed) ([]responseFromIngesters, error) {
+		resp, err := f(ctx, job.client)
+		if err != nil {
+			return nil, err
+		}
+		return []responseFromIngesters{{addr: job.addr, response: resp}}, nil
+	})
+}
+
+// NewPartitionContext creates a new partition context
+// This is used to track which ingesters were used in the query and reuse the same ingesters for consecutive queries
+func NewPartitionContext(ctx context.Context) context.Context {
+	return context.WithValue(ctx, partitionCtxKey, &PartitionContext{
+		ingestersUsed: make(map[string]PartitionIngesterUsed),
+	})
+}
+
+func ExtractPartitionContext(ctx context.Context) *PartitionContext {
+	v, ok := ctx.Value(partitionCtxKey).(*PartitionContext)
+	if !ok {
+		return &PartitionContext{
+			ingestersUsed: make(map[string]PartitionIngesterUsed),
+		}
+	}
+	return v
+}
+
 // forAllIngesters runs f, in parallel, for all ingesters
-// waitForAllResponses param can be used to require results from all ingesters in the replication set. If this is set to false, the call will return as soon as we have a quorum by zone. Only valid for partition-ingesters.
-func (q *IngesterQuerier) forAllIngesters(ctx context.Context, waitForAllResponses bool, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
+func (q *IngesterQuerier) forAllIngesters(ctx context.Context, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
 	if q.querierConfig.QueryPartitionIngesters {
+		ExtractPartitionContext(ctx).SetIsPartitioned(true)
 		tenantID, err := user.ExtractOrgID(ctx)
 		if err != nil {
 			return nil, err
@@ -99,7 +181,7 @@ func (q *IngesterQuerier) forAllIngesters(ctx context.Context, waitForAllRespons
 		if err != nil {
 			return nil, err
 		}
-		return q.forGivenIngesterSets(ctx, waitForAllResponses, replicationSets, f)
+		return q.forGivenIngesterSets(ctx, replicationSets, f)
 	}
 
 	replicationSet, err := q.ring.GetReplicationSetForOperation(ring.Read)
@@ -111,19 +193,13 @@ func (q *IngesterQuerier) forAllIngesters(ctx context.Context, waitForAllRespons
 }
 
 // forGivenIngesterSets runs f, in parallel, for given ingester sets
-// waitForAllResponses param can be used to require results from all ingesters in all replication sets. If this is set to false, the call will return as soon as we have a quorum by zone.
-func (q *IngesterQuerier) forGivenIngesterSets(ctx context.Context, waitForAllResponses bool, replicationSet []ring.ReplicationSet, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
+func (q *IngesterQuerier) forGivenIngesterSets(ctx context.Context, replicationSet []ring.ReplicationSet, f func(context.Context, logproto.QuerierClient) (interface{}, error)) ([]responseFromIngesters, error) {
 	// Enable minimize requests if we can, so we initially query a single ingester per replication set, as each replication-set is one partition.
 	// Ingesters must supply zone information for this to have an effect.
 	config := ring.DoUntilQuorumConfig{
-		MinimizeRequests: !waitForAllResponses,
+		MinimizeRequests: true,
 	}
 	return concurrency.ForEachJobMergeResults[ring.ReplicationSet, responseFromIngesters](ctx, replicationSet, 0, func(ctx context.Context, set ring.ReplicationSet) ([]responseFromIngesters, error) {
-		if waitForAllResponses {
-			// Tell the ring we need to return all responses from all zones
-			set.MaxErrors = 0
-			set.MaxUnavailableZones = 0
-		}
 		return q.forGivenIngesters(ctx, set, config, f)
 	})
 }
@@ -135,17 +211,16 @@ func (q *IngesterQuerier) forGivenIngesters(ctx context.Context, replicationSet
 		if err != nil {
 			return responseFromIngesters{addr: ingester.Addr}, err
 		}
-
 		resp, err := f(ctx, client.(logproto.QuerierClient))
 		if err != nil {
 			return responseFromIngesters{addr: ingester.Addr}, err
 		}
+		ExtractPartitionContext(ctx).AddClient(client.(logproto.QuerierClient), ingester.Addr)
 		return responseFromIngesters{ingester.Addr, resp}, nil
-	}, func(responseFromIngesters) {
-		// Nothing to do
+	}, func(cleanup responseFromIngesters) {
+		ExtractPartitionContext(ctx).RemoveClient(cleanup.addr)
 	})
-
 	if err != nil {
 		return nil, err
 	}
@@ -157,7 +232,7 @@ func (q *IngesterQuerier) forGivenIngesters(ctx context.Context, replicationSet
 }
 
 func (q *IngesterQuerier) SelectLogs(ctx context.Context, params logql.SelectLogParams) ([]iter.EntryIterator, error) {
-	resps, err := q.forAllIngesters(ctx, false, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
+	resps, err := q.forAllIngesters(ctx, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
 		stats.FromContext(ctx).AddIngesterReached(1)
 		return client.Query(ctx, params.QueryRequest)
 	})
@@ -173,7 +248,7 @@ func (q *IngesterQuerier) SelectLogs(ctx context.Context, params logql.SelectLog
 }
 
 func (q *IngesterQuerier) SelectSample(ctx context.Context, params logql.SelectSampleParams) ([]iter.SampleIterator, error) {
-	resps, err := q.forAllIngesters(ctx, false, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
+	resps, err := q.forAllIngesters(ctx, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
 		stats.FromContext(ctx).AddIngesterReached(1)
 		return client.QuerySample(ctx, params.SampleQueryRequest)
 	})
@@ -189,7 +264,7 @@ func (q *IngesterQuerier) SelectSample(ctx context.Context, params logql.SelectS
 }
 
 func (q *IngesterQuerier) Label(ctx context.Context, req *logproto.LabelRequest) ([][]string, error) {
-	resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
+	resps, err := q.forAllIngesters(ctx, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
 		return client.Label(ctx, req)
 	})
 	if err != nil {
@@ -205,7 +280,7 @@ func (q *IngesterQuerier) Label(ctx context.Context, req *logproto.LabelRequest)
 }
 
 func (q *IngesterQuerier) Tail(ctx context.Context, req *logproto.TailRequest) (map[string]logproto.Querier_TailClient, error) {
-	resps, err := q.forAllIngesters(ctx, false, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
+	resps, err := q.forAllIngesters(ctx, func(_ context.Context, client logproto.QuerierClient) (interface{}, error) {
 		return client.Tail(ctx, req)
 	})
 	if err != nil {
@@ -270,7 +345,7 @@ func (q *IngesterQuerier) TailDisconnectedIngesters(ctx context.Context, req *lo
 }
 
 func (q *IngesterQuerier) Series(ctx context.Context, req *logproto.SeriesRequest) ([][]logproto.SeriesIdentifier, error) {
-	resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
+	resps, err := q.forAllIngesters(ctx, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
 		return client.Series(ctx, req)
 	})
 	if err != nil {
@@ -325,15 +400,22 @@ func (q *IngesterQuerier) TailersCount(ctx context.Context) ([]uint32, error) {
 }
 
 func (q *IngesterQuerier) GetChunkIDs(ctx context.Context, from, through model.Time, matchers ...*labels.Matcher) ([]string, error) {
-	// We must wait for all responses when using partition-ingesters to avoid a race between Query and GetChunkIDs calls.
-	// This occurs if call Query on an ingester after a recent flush then call GetChunkIDs on a different, unflushed ingester in the same partition.
-	resps, err := q.forAllIngesters(ctx, q.querierConfig.QueryPartitionIngesters, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
+	ingesterQueryFn := q.forAllIngesters
+
+	partitionCtx := ExtractPartitionContext(ctx)
+	if partitionCtx.IsPartitioned() {
+		// We need to query the same ingesters as the previous query
+		ingesterQueryFn = partitionCtx.forQueriedIngesters
+	}
+
+	resps, err := ingesterQueryFn(ctx, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
 		return querierClient.GetChunkIDs(ctx, &logproto.GetChunkIDsRequest{
 			Matchers: convertMatchersToString(matchers),
 			Start:    from.Time(),
 			End:      through.Time(),
 		})
 	})
+
 	if err != nil {
 		return nil, err
 	}
@@ -347,14 +429,13 @@ func (q *IngesterQuerier) GetChunkIDs(ctx context.Context, from, through model.T
 }
 
 func (q *IngesterQuerier) Stats(ctx context.Context, _ string, from, through model.Time, matchers ...*labels.Matcher) (*index_stats.Stats, error) {
-	resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
+	resps, err := q.forAllIngesters(ctx, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
 		return querierClient.GetStats(ctx, &logproto.IndexStatsRequest{
 			From:     from,
 			Through:  through,
 			Matchers: syntax.MatchersString(matchers),
 		})
 	})
-
 	if err != nil {
 		if isUnimplementedCallError(err) {
 			// Handle communication with older ingesters gracefully
@@ -378,7 +459,7 @@ func (q *IngesterQuerier) Volume(ctx context.Context, _ string, from, through mo
 		matcherString = syntax.MatchersString(matchers)
 	}
 
-	resps, err := q.forAllIngesters(ctx, false, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
+	resps, err := q.forAllIngesters(ctx, func(ctx context.Context, querierClient logproto.QuerierClient) (interface{}, error) {
 		return querierClient.GetVolume(ctx, &logproto.VolumeRequest{
 			From:    from,
 			Through: through,
@@ -388,7 +469,6 @@ func (q *IngesterQuerier) Volume(ctx context.Context, _ string, from, through mo
 			AggregateBy: aggregateBy,
 		})
 	})
-
 	if err != nil {
 		if isUnimplementedCallError(err) {
 			// Handle communication with older ingesters gracefully
@@ -407,10 +487,9 @@ func (q *IngesterQuerier) Volume(ctx context.Context, _ string, from, through mo
 }
 
 func (q *IngesterQuerier) DetectedLabel(ctx context.Context, req *logproto.DetectedLabelsRequest) (*logproto.LabelToValuesResponse, error) {
-	ingesterResponses, err := q.forAllIngesters(ctx, false, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
+	ingesterResponses, err := q.forAllIngesters(ctx, func(ctx context.Context, client logproto.QuerierClient) (interface{}, error) {
 		return client.GetDetectedLabels(ctx, req)
 	})
-
 	if err != nil {
 		level.Error(q.logger).Log("msg", "error getting detected labels", "err", err)
 		return nil, err
diff --git a/pkg/querier/ingester_querier_test.go b/pkg/querier/ingester_querier_test.go
index a5066176ad785..268191bd17a72 100644
--- a/pkg/querier/ingester_querier_test.go
+++ b/pkg/querier/ingester_querier_test.go
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/go-kit/log"
+	"github.com/grafana/dskit/ring/client"
 	"github.com/grafana/dskit/user"
 	"go.uber.org/atomic"
 
@@ -241,11 +242,10 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
 	}
 
 	tests := map[string]struct {
-		method             string
-		testFn             func(*IngesterQuerier) error
-		retVal             interface{}
-		shards             int
-		expectAllResponses bool
+		method string
+		testFn func(*IngesterQuerier) error
+		retVal interface{}
+		shards int
 	}{
 		"label": {
 			method: "Label",
@@ -269,8 +269,7 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
 				_, err := ingesterQuerier.GetChunkIDs(ctx, model.Time(0), model.Time(0))
 				return err
 			},
-			retVal:             new(logproto.GetChunkIDsResponse),
-			expectAllResponses: true,
+			retVal: new(logproto.GetChunkIDsResponse),
 		},
 		"select_logs": {
 			method: "Query",
@@ -330,7 +329,7 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
 			ingestersPerPartition := len(ingesters) / partitions
 			assert.Greaterf(t, ingestersPerPartition, 1, "must have more than one ingester per partition")
 
-			ingesterQuerier, err := newTestPartitionIngesterQuerier(ingesterClient, instanceRing, newPartitionInstanceRingMock(instanceRing, ingesters, partitions, ingestersPerPartition), testData.shards)
+			ingesterQuerier, err := newTestPartitionIngesterQuerier(newIngesterClientMockFactory(ingesterClient), instanceRing, newPartitionInstanceRingMock(instanceRing, ingesters, partitions, ingestersPerPartition), testData.shards)
 			require.NoError(t, err)
 
 			ingesterQuerier.querierConfig.QueryPartitionIngesters = true
@@ -342,9 +341,6 @@ func TestIngesterQuerierFetchesResponsesFromPartitionIngesters(t *testing.T) {
 				testData.shards = partitions
 			}
 			expectedCalls := min(testData.shards, partitions)
-			if testData.expectAllResponses {
-				expectedCalls = expectedCalls * ingestersPerPartition
-			}
 			// Wait for responses: We expect one request per queried partition because we have request minimization enabled & ingesters are in multiple zones.
 			// If shuffle sharding is enabled, we expect one query per shard as we write to a subset of partitions.
 			require.Eventually(t, func() bool { return cnt.Load() >= int32(expectedCalls) }, time.Millisecond*100, time.Millisecond*1, "expected all ingesters to respond")
@@ -353,6 +349,137 @@
 	}
 }
 
+func TestIngesterQuerier_QueriesSameIngestersWithPartitionContext(t *testing.T) {
+	t.Parallel()
+	userCtx := user.InjectOrgID(context.Background(), "test-user")
+	testCtx, cancel := context.WithTimeout(userCtx, time.Second*10)
+	defer cancel()
+
+	ingesters := []ring.InstanceDesc{
+		mockInstanceDescWithZone("1.1.1.1", ring.ACTIVE, "A"),
+		mockInstanceDescWithZone("2.2.2.2", ring.ACTIVE, "B"),
+		mockInstanceDescWithZone("3.3.3.3", ring.ACTIVE, "A"),
+		mockInstanceDescWithZone("4.4.4.4", ring.ACTIVE, "B"),
+		mockInstanceDescWithZone("5.5.5.5", ring.ACTIVE, "A"),
+		mockInstanceDescWithZone("6.6.6.6", ring.ACTIVE, "B"),
+	}
+
+	tests := map[string]struct {
+		method string
+		testFn func(context.Context, *IngesterQuerier) error
+		retVal interface{}
+		shards int
+	}{
+		"select_logs": {
+			method: "Query",
+			testFn: func(ctx context.Context, ingesterQuerier *IngesterQuerier) error {
+				_, err := ingesterQuerier.SelectLogs(ctx, logql.SelectLogParams{
+					QueryRequest: new(logproto.QueryRequest),
+				})
+				return err
+			},
+			retVal: newQueryClientMock(),
+		},
+		"select_sample": {
+			method: "QuerySample",
+			testFn: func(ctx context.Context, ingesterQuerier *IngesterQuerier) error {
+				_, err := ingesterQuerier.SelectSample(ctx, logql.SelectSampleParams{
+					SampleQueryRequest: new(logproto.SampleQueryRequest),
+				})
+				return err
+			},
+			retVal: newQuerySampleClientMock(),
+		},
+		"select_logs_shuffle_sharded": {
+			method: "Query",
+			testFn: func(ctx context.Context, ingesterQuerier *IngesterQuerier) error {
+				_, err := ingesterQuerier.SelectLogs(ctx, logql.SelectLogParams{
+					QueryRequest: new(logproto.QueryRequest),
+				})
+				return err
+			},
+			retVal: newQueryClientMock(),
+			shards: 2, // Must be less than number of partitions
+		},
+	}
+
+	for testName, testData := range tests {
+		cnt := atomic.NewInt32(0)
+		ctx := NewPartitionContext(testCtx)
+
+		t.Run(testName, func(t *testing.T) {
+			cnt.Store(0)
+			runFn := func(args mock.Arguments) {
+				ctx := args[0].(context.Context)
+
+				select {
+				case <-ctx.Done():
+					// should not be cancelled by the tracker
+					require.NoErrorf(t, ctx.Err(), "tracker should not cancel ctx: %v", context.Cause(ctx))
+				default:
+					cnt.Add(1)
+				}
+			}
+
+			instanceRing := newReadRingMock(ingesters, 0)
+			ingesterClient := newQuerierClientMock()
+			ingesterClient.On(testData.method, mock.Anything, mock.Anything, mock.Anything).Return(testData.retVal, nil).Run(runFn)
+			ingesterClient.On("GetChunkIDs", mock.Anything, mock.Anything, mock.Anything).Return(new(logproto.GetChunkIDsResponse), nil).Run(runFn)
+
+			partitions := 3
+			ingestersPerPartition := len(ingesters) / partitions
+			assert.Greaterf(t, ingestersPerPartition, 1, "must have more than one ingester per partition")
+
+			mockClientFactory := mockIngesterClientFactory{
+				requestedClients: make(map[string]int),
+			}
+
+			ingesterQuerier, err := newTestPartitionIngesterQuerier(mockClientFactory.newIngesterClientMockFactory(ingesterClient), instanceRing, newPartitionInstanceRingMock(instanceRing, ingesters, partitions, ingestersPerPartition), testData.shards)
+			require.NoError(t, err)
+
+			ingesterQuerier.querierConfig.QueryPartitionIngesters = true
+
+			err = testData.testFn(ctx, ingesterQuerier)
+			require.NoError(t, err)
+
+			if testData.shards == 0 {
+				testData.shards = partitions
+			}
+
+			expectedCalls := min(testData.shards, partitions)
+			expectedIngesterCalls := expectedCalls
+			// Wait for responses: We expect one request per queried partition because we have request minimization enabled & ingesters are in multiple zones.
+			// If shuffle sharding is enabled, we expect one query per shard as we write to a subset of partitions.
+			require.Eventually(t, func() bool { return cnt.Load() >= int32(expectedCalls) }, time.Millisecond*100, time.Millisecond*1, "expected ingesters to respond")
+			ingesterClient.AssertNumberOfCalls(t, testData.method, expectedCalls)
+
+			partitionCtx := ExtractPartitionContext(ctx)
+			require.Equal(t, expectedIngesterCalls, len(partitionCtx.ingestersUsed))
+			require.Equal(t, expectedIngesterCalls, len(mockClientFactory.requestedClients))
+
+			for _, ingester := range partitionCtx.ingestersUsed {
+				count, ok := mockClientFactory.requestedClients[ingester.addr]
+				require.True(t, ok)
+				require.Equal(t, count, 1)
+			}
+
+			// Now call getChunkIDs to ensure we only call the same ingesters as before.
+			_, err = ingesterQuerier.GetChunkIDs(ctx, model.Time(0), model.Time(1))
+			require.NoError(t, err)
+
+			require.Eventually(t, func() bool { return cnt.Load() >= int32(expectedCalls) }, time.Millisecond*100, time.Millisecond*1, "expected ingesters to respond")
+			ingesterClient.AssertNumberOfCalls(t, "GetChunkIDs", expectedCalls)
+
+			// Finally, confirm we called the same ingesters again and didn't ask for any new clients
+			require.Equal(t, expectedIngesterCalls, len(mockClientFactory.requestedClients))
+			for _, ingester := range partitionCtx.ingestersUsed {
+				count, ok := mockClientFactory.requestedClients[ingester.addr]
+				require.True(t, ok)
+				require.Equal(t, count, 1)
+			}
+		})
+	}
+}
+
 func TestQuerier_tailDisconnectedIngesters(t *testing.T) {
 	t.Parallel()
 
@@ -540,14 +667,14 @@ func newTestIngesterQuerier(readRingMock *readRingMock, ingesterClient *querierC
 	)
 }
 
-func newTestPartitionIngesterQuerier(ingesterClient *querierClientMock, instanceRing *readRingMock, partitionRing *ring.PartitionInstanceRing, tenantShards int) (*IngesterQuerier, error) {
+func newTestPartitionIngesterQuerier(clientFactory client.PoolFactory, instanceRing *readRingMock, partitionRing *ring.PartitionInstanceRing, tenantShards int) (*IngesterQuerier, error) {
 	return newIngesterQuerier(
 		mockQuerierConfig(),
 		mockIngesterClientConfig(),
 		instanceRing,
 		partitionRing,
 		func(string) int { return tenantShards },
-		newIngesterClientMockFactory(ingesterClient),
+		clientFactory,
 		constants.Loki,
 		log.NewNopLogger(),
 	)
diff --git a/pkg/querier/querier.go b/pkg/querier/querier.go
index 947101c0aa2ff..3f03d1e037aad 100644
--- a/pkg/querier/querier.go
+++ b/pkg/querier/querier.go
@@ -152,6 +152,9 @@ func New(cfg Config, store Store, ingesterQuerier *IngesterQuerier, limits Limit
 
 // Select Implements logql.Querier which select logs via matchers and regex filters.
 func (q *SingleTenantQuerier) SelectLogs(ctx context.Context, params logql.SelectLogParams) (iter.EntryIterator, error) {
+	// Create a new partition context for the query
+	// This is used to track which ingesters were used in the query and reuse the same ingesters for consecutive queries
+	ctx = NewPartitionContext(ctx)
 	var err error
 	params.Start, params.End, err = q.validateQueryRequest(ctx, params)
 	if err != nil {
@@ -211,6 +214,9 @@ func (q *SingleTenantQuerier) SelectLogs(ctx context.Context, params logql.Selec
 }
 
 func (q *SingleTenantQuerier) SelectSamples(ctx context.Context, params logql.SelectSampleParams) (iter.SampleIterator, error) {
+	// Create a new partition context for the query
+	// This is used to track which ingesters were used in the query and reuse the same ingesters for consecutive queries
+	ctx = NewPartitionContext(ctx)
 	var err error
 	params.Start, params.End, err = q.validateQueryRequest(ctx, params)
 	if err != nil {
diff --git a/pkg/querier/querier_mock_test.go b/pkg/querier/querier_mock_test.go
index ab70de4baacea..0fd9b421de000 100644
--- a/pkg/querier/querier_mock_test.go
+++ b/pkg/querier/querier_mock_test.go
@@ -142,6 +142,19 @@ func (c *querierClientMock) Close() error {
 	return nil
 }
 
+type mockIngesterClientFactory struct {
+	requestedClients map[string]int
+}
+
+// newIngesterClientMockFactory creates a factory function always returning
+// the input querierClientMock
+func (f mockIngesterClientFactory) newIngesterClientMockFactory(c *querierClientMock) ring_client.PoolFactory {
+	return ring_client.PoolAddrFunc(func(addr string) (ring_client.PoolClient, error) {
+		f.requestedClients[addr]++
+		return c, nil
+	})
+}
+
 // newIngesterClientMockFactory creates a factory function always returning
 // the input querierClientMock
 func newIngesterClientMockFactory(c *querierClientMock) ring_client.PoolFactory {