diff --git a/pkg/loki/loki.go b/pkg/loki/loki.go index 0b09ee0491849..718b5fd13c195 100644 --- a/pkg/loki/loki.go +++ b/pkg/loki/loki.go @@ -312,7 +312,7 @@ type Loki struct { runtimeConfig *runtimeconfig.Manager MemberlistKV *memberlist.KVInitService compactor *compactor.Compactor - QueryFrontEndTripperware queryrangebase.Tripperware + QueryFrontEndMiddleware queryrangebase.Middleware queryScheduler *scheduler.Scheduler querySchedulerRingManager *lokiring.RingManager usageReport *analytics.Reporter @@ -590,7 +590,7 @@ func (t *Loki) setupModuleManager() error { mm.RegisterModule(Ingester, t.initIngester) mm.RegisterModule(Querier, t.initQuerier) mm.RegisterModule(IngesterQuerier, t.initIngesterQuerier) - mm.RegisterModule(QueryFrontendTripperware, t.initQueryFrontendTripperware, modules.UserInvisibleModule) + mm.RegisterModule(QueryFrontendTripperware, t.initQueryFrontendMiddleware, modules.UserInvisibleModule) mm.RegisterModule(QueryFrontend, t.initQueryFrontend) mm.RegisterModule(RulerStorage, t.initRulerStorage, modules.UserInvisibleModule) mm.RegisterModule(Ruler, t.initRuler) @@ -653,26 +653,17 @@ func (t *Loki) setupModuleManager() error { level.Debug(util_log.Logger).Log("msg", "per-query request limits support enabled") mm.RegisterModule(QueryLimiter, t.initQueryLimiter, modules.UserInvisibleModule) mm.RegisterModule(QueryLimitsInterceptors, t.initQueryLimitsInterceptors, modules.UserInvisibleModule) - mm.RegisterModule(QueryLimitsTripperware, t.initQueryLimitsTripperware, modules.UserInvisibleModule) + + // This module is defunct but the target remains for backwards compatibility. + mm.RegisterModule(QueryLimitsTripperware, func() (services.Service, error) { return nil, nil }, modules.UserInvisibleModule) // Ensure query limiter embeds overrides after they've been // created. deps[QueryLimiter] = []string{Overrides} deps[QueryLimitsInterceptors] = []string{} - // Ensure query limits tripperware embeds the query frontend - // tripperware after it's been created. Any additional - // middleware/tripperware you want to add to the querier or - // frontend must happen inject a dependence on the query limits - // tripperware. - deps[QueryLimitsTripperware] = []string{QueryFrontendTripperware} - deps[Querier] = append(deps[Querier], QueryLimiter) - // The frontend receives a tripperware. Make sure it uses the - // wrapped one. - deps[QueryFrontend] = append(deps[QueryFrontend], QueryLimitsTripperware) - // query frontend tripperware uses t.Overrides. Make sure it // uses the one wrapped by query limiter. 
deps[QueryFrontendTripperware] = append(deps[QueryFrontendTripperware], QueryLimiter) diff --git a/pkg/loki/modules.go b/pkg/loki/modules.go index f5a46e2164ed4..dd53f3f262195 100644 --- a/pkg/loki/modules.go +++ b/pkg/loki/modules.go @@ -781,10 +781,10 @@ type disabledShuffleShardingLimits struct{} func (disabledShuffleShardingLimits) MaxQueriersPerUser(_ string) int { return 0 } -func (t *Loki) initQueryFrontendTripperware() (_ services.Service, err error) { +func (t *Loki) initQueryFrontendMiddleware() (_ services.Service, err error) { level.Debug(util_log.Logger).Log("msg", "initializing query frontend tripperware") - tripperware, stopper, err := queryrange.NewTripperware( + middleware, stopper, err := queryrange.NewMiddleware( t.Cfg.QueryRange, t.Cfg.Querier.Engine, util_log.Logger, @@ -797,7 +797,7 @@ func (t *Loki) initQueryFrontendTripperware() (_ services.Service, err error) { return } t.stopper = stopper - t.QueryFrontEndTripperware = tripperware + t.QueryFrontEndMiddleware = middleware return services.NewIdleService(nil, nil), nil } @@ -864,13 +864,15 @@ func (t *Loki) initQueryFrontend() (_ services.Service, err error) { FrontendV2: t.Cfg.Frontend.FrontendV2, DownstreamURL: t.Cfg.Frontend.DownstreamURL, } - roundTripper, frontendV1, frontendV2, err := frontend.InitFrontend( + frontendTripper, frontendV1, frontendV2, err := frontend.InitFrontend( combinedCfg, scheduler.SafeReadRing(t.Cfg.QueryScheduler, t.querySchedulerRingManager), disabledShuffleShardingLimits{}, t.Cfg.Server.GRPCListenPort, util_log.Logger, - prometheus.DefaultRegisterer) + prometheus.DefaultRegisterer, + queryrange.DefaultCodec, + ) if err != nil { return nil, err } @@ -887,7 +889,7 @@ func (t *Loki) initQueryFrontend() (_ services.Service, err error) { level.Debug(util_log.Logger).Log("msg", "no query frontend configured") } - roundTripper = t.QueryFrontEndTripperware(roundTripper) + roundTripper := queryrange.NewSerializeRoundTripper(t.QueryFrontEndMiddleware.Wrap(frontendTripper), queryrange.DefaultCodec) frontendHandler := transport.NewHandler(t.Cfg.Frontend.Handler, roundTripper, util_log.Logger, prometheus.DefaultRegisterer) if t.Cfg.Frontend.CompressResponses { @@ -1477,15 +1479,6 @@ func (t *Loki) initQueryLimitsInterceptors() (services.Service, error) { return nil, nil } -func (t *Loki) initQueryLimitsTripperware() (services.Service, error) { - _ = level.Debug(util_log.Logger).Log("msg", "initializing query limits tripperware") - t.QueryFrontEndTripperware = querylimits.WrapTripperware( - t.QueryFrontEndTripperware, - ) - - return nil, nil -} - func (t *Loki) initAnalytics() (services.Service, error) { if !t.Cfg.Analytics.Enabled { return nil, nil diff --git a/pkg/lokifrontend/frontend/config.go b/pkg/lokifrontend/frontend/config.go index cdbfca34ac6f9..290c097de2669 100644 --- a/pkg/lokifrontend/frontend/config.go +++ b/pkg/lokifrontend/frontend/config.go @@ -12,6 +12,7 @@ import ( "github.com/grafana/loki/pkg/lokifrontend/frontend/transport" v1 "github.com/grafana/loki/pkg/lokifrontend/frontend/v1" v2 "github.com/grafana/loki/pkg/lokifrontend/frontend/v2" + "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" "github.com/grafana/loki/pkg/util" ) @@ -38,7 +39,7 @@ func (cfg *CombinedFrontendConfig) RegisterFlags(f *flag.FlagSet) { // Returned RoundTripper can be wrapped in more round-tripper middlewares, and then eventually registered // into HTTP server using the Handler from this package. 
Returned RoundTripper is always non-nil // (if there are no errors), and it uses the returned frontend (if any). -func InitFrontend(cfg CombinedFrontendConfig, ring ring.ReadRing, limits v1.Limits, grpcListenPort int, log log.Logger, reg prometheus.Registerer) (http.RoundTripper, *v1.Frontend, *v2.Frontend, error) { +func InitFrontend(cfg CombinedFrontendConfig, ring ring.ReadRing, limits v1.Limits, grpcListenPort int, log log.Logger, reg prometheus.Registerer, codec transport.Codec) (queryrangebase.Handler, *v1.Frontend, *v2.Frontend, error) { switch { case cfg.DownstreamURL != "": // If the user has specified a downstream Prometheus, then we should use that. @@ -59,8 +60,8 @@ func InitFrontend(cfg CombinedFrontendConfig, ring ring.ReadRing, limits v1.Limi cfg.FrontendV2.Port = grpcListenPort } - fr, err := v2.NewFrontend(cfg.FrontendV2, ring, log, reg) - return transport.AdaptGrpcRoundTripperToHTTPRoundTripper(fr), nil, fr, err + fr, err := v2.NewFrontend(cfg.FrontendV2, ring, log, reg, codec) + return fr, nil, fr, err default: // No scheduler = use original frontend. @@ -68,6 +69,6 @@ func InitFrontend(cfg CombinedFrontendConfig, ring ring.ReadRing, limits v1.Limi if err != nil { return nil, nil, nil, err } - return transport.AdaptGrpcRoundTripperToHTTPRoundTripper(fr), fr, nil, nil + return transport.AdaptGrpcRoundTripperToHandler(fr, codec), fr, nil, nil } } diff --git a/pkg/lokifrontend/frontend/downstream_roundtripper.go b/pkg/lokifrontend/frontend/downstream_roundtripper.go index d52ced81938ab..90f330900c32b 100644 --- a/pkg/lokifrontend/frontend/downstream_roundtripper.go +++ b/pkg/lokifrontend/frontend/downstream_roundtripper.go @@ -1,20 +1,26 @@ package frontend import ( + "context" + "fmt" "net/http" "net/url" "path" + "github.com/grafana/dskit/user" "github.com/opentracing/opentracing-go" + + "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" ) // RoundTripper that forwards requests to downstream URL. 
type downstreamRoundTripper struct { downstreamURL *url.URL transport http.RoundTripper + codec queryrangebase.Codec } -func NewDownstreamRoundTripper(downstreamURL string, transport http.RoundTripper) (http.RoundTripper, error) { +func NewDownstreamRoundTripper(downstreamURL string, transport http.RoundTripper) (queryrangebase.Handler, error) { u, err := url.Parse(downstreamURL) if err != nil { return nil, err @@ -23,8 +29,19 @@ func NewDownstreamRoundTripper(downstreamURL string, transport http.RoundTripper return &downstreamRoundTripper{downstreamURL: u, transport: transport}, nil } -func (d downstreamRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - tracer, span := opentracing.GlobalTracer(), opentracing.SpanFromContext(r.Context()) +func (d downstreamRoundTripper) Do(ctx context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { + tracer, span := opentracing.GlobalTracer(), opentracing.SpanFromContext(ctx) + + var r *http.Request + + r, err := d.codec.EncodeRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("cannot convert request to HTTP request: %w", err) + } + if err := user.InjectOrgIDIntoHTTPRequest(ctx, r); err != nil { + return nil, err + } + if tracer != nil && span != nil { carrier := opentracing.HTTPHeadersCarrier(r.Header) err := tracer.Inject(span.Context(), opentracing.HTTPHeaders, carrier) @@ -37,5 +54,16 @@ func (d downstreamRoundTripper) RoundTrip(r *http.Request) (*http.Response, erro r.URL.Host = d.downstreamURL.Host r.URL.Path = path.Join(d.downstreamURL.Path, r.URL.Path) r.Host = "" - return d.transport.RoundTrip(r) + + httpResp, err := d.transport.RoundTrip(r) + if err != nil { + return nil, err + } + + resp, err := d.codec.DecodeResponse(ctx, httpResp, req) + if err != nil { + return nil, fmt.Errorf("cannot convert HTTP response to response: %w", err) + } + + return resp, nil } diff --git a/pkg/lokifrontend/frontend/transport/handler.go b/pkg/lokifrontend/frontend/transport/handler.go index 8cd8ca0bbd8e0..92c29cc896443 100644 --- a/pkg/lokifrontend/frontend/transport/handler.go +++ b/pkg/lokifrontend/frontend/transport/handler.go @@ -16,11 +16,13 @@ import ( "github.com/go-kit/log/level" "github.com/grafana/dskit/httpgrpc" "github.com/grafana/dskit/httpgrpc/server" + "github.com/grafana/dskit/user" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/grafana/dskit/tenant" + "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" querier_stats "github.com/grafana/loki/pkg/querier/stats" "github.com/grafana/loki/pkg/util" util_log "github.com/grafana/loki/pkg/util/log" @@ -252,3 +254,35 @@ func statsValue(name string, d time.Duration) string { durationInMs := strconv.FormatFloat(float64(d)/float64(time.Millisecond), 'f', -1, 64) return name + ";dur=" + durationInMs } + +func AdaptGrpcRoundTripperToHandler(r GrpcRoundTripper, codec Codec) queryrangebase.Handler { + return &grpcRoundTripperToHandlerAdapter{roundTripper: r, codec: codec} +} + +// This adapter wraps GrpcRoundTripper and converts it into a queryrangebase.Handler +type grpcRoundTripperToHandlerAdapter struct { + roundTripper GrpcRoundTripper + codec Codec +} + +func (a *grpcRoundTripperToHandlerAdapter) Do(ctx context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { + httpReq, err := a.codec.EncodeRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("cannot convert request to HTTP request: %w", err) + } + if err := user.InjectOrgIDIntoHTTPRequest(ctx,
httpReq); err != nil { + return nil, err + } + + grpcReq, err := server.HTTPRequest(httpReq) + if err != nil { + return nil, fmt.Errorf("cannot convert HTTP request to gRPC request: %w", err) + } + + grpcResp, err := a.roundTripper.RoundTripGRPC(ctx, grpcReq) + if err != nil { + return nil, err + } + + return a.codec.DecodeHTTPGrpcResponse(grpcResp, req) +} diff --git a/pkg/lokifrontend/frontend/transport/roundtripper.go b/pkg/lokifrontend/frontend/transport/roundtripper.go index 8d8993649555c..58f9d13aa98fe 100644 --- a/pkg/lokifrontend/frontend/transport/roundtripper.go +++ b/pkg/lokifrontend/frontend/transport/roundtripper.go @@ -1,13 +1,11 @@ package transport import ( - "bytes" "context" - "io" - "net/http" "github.com/grafana/dskit/httpgrpc" - "github.com/grafana/dskit/httpgrpc/server" + + "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" ) // GrpcRoundTripper is similar to http.RoundTripper, but works with HTTP requests converted to protobuf messages. @@ -15,43 +13,7 @@ type GrpcRoundTripper interface { RoundTripGRPC(context.Context, *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) } -func AdaptGrpcRoundTripperToHTTPRoundTripper(r GrpcRoundTripper) http.RoundTripper { - return &grpcRoundTripperAdapter{roundTripper: r} -} - -// This adapter wraps GrpcRoundTripper and converted it into http.RoundTripper -type grpcRoundTripperAdapter struct { - roundTripper GrpcRoundTripper -} - -type buffer struct { - buff []byte - io.ReadCloser -} - -func (b *buffer) Bytes() []byte { - return b.buff -} - -func (a *grpcRoundTripperAdapter) RoundTrip(r *http.Request) (*http.Response, error) { - req, err := server.HTTPRequest(r) - if err != nil { - return nil, err - } - - resp, err := a.roundTripper.RoundTripGRPC(r.Context(), req) - if err != nil { - return nil, err - } - - httpResp := &http.Response{ - StatusCode: int(resp.Code), - Body: &buffer{buff: resp.Body, ReadCloser: io.NopCloser(bytes.NewReader(resp.Body))}, - Header: http.Header{}, - ContentLength: int64(len(resp.Body)), - } - for _, h := range resp.Headers { - httpResp.Header[h.Key] = h.Values - } - return httpResp, nil +type Codec interface { + queryrangebase.Codec + DecodeHTTPGrpcResponse(r *httpgrpc.HTTPResponse, req queryrangebase.Request) (queryrangebase.Response, error) } diff --git a/pkg/lokifrontend/frontend/v1/frontend_test.go b/pkg/lokifrontend/frontend/v1/frontend_test.go index cbe34776e6ecf..bd417f4885985 100644 --- a/pkg/lokifrontend/frontend/v1/frontend_test.go +++ b/pkg/lokifrontend/frontend/v1/frontend_test.go @@ -28,6 +28,7 @@ import ( "go.uber.org/atomic" "google.golang.org/grpc" + "github.com/grafana/loki/pkg/loghttp" "github.com/grafana/loki/pkg/lokifrontend/frontend/transport" "github.com/grafana/loki/pkg/lokifrontend/frontend/v1/frontendv1pb" "github.com/grafana/loki/pkg/querier/queryrange" @@ -44,7 +45,7 @@ const ( func TestFrontend(t *testing.T) { handler := queryrangebase.HandlerFunc(func(ctx context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { - return &queryrange.LokiLabelNamesResponse{Data: []string{"Hello", "world"}}, nil + return &queryrange.LokiLabelNamesResponse{Data: []string{"Hello", "world"}, Version: uint32(loghttp.VersionV1)}, nil }) test := func(addr string, _ *Frontend) { req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/%s", addr, labelQuery), nil) @@ -81,7 +82,7 @@ func TestFrontendPropagateTrace(t *testing.T) { traceID := fmt.Sprintf("%v", sp.Context().(jaeger.SpanContext).TraceID()) observedTraceID <- traceID - return 
&queryrange.LokiLabelNamesResponse{Data: []string{"Hello", "world"}}, nil + return &queryrange.LokiLabelNamesResponse{Data: []string{"Hello", "world"}, Version: uint32(loghttp.VersionV1)}, nil }) test := func(addr string, _ *Frontend) { @@ -186,7 +187,7 @@ func TestFrontendCancel(t *testing.T) { func TestFrontendMetricsCleanup(t *testing.T) { handler := queryrangebase.HandlerFunc(func(ctx context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { - return &queryrange.LokiLabelNamesResponse{Data: []string{"Hello", "world"}}, nil + return &queryrange.LokiLabelNamesResponse{Data: []string{"Hello", "world"}, Version: uint32(loghttp.VersionV1)}, nil }) for _, matchMaxConcurrency := range []bool{false, true} { @@ -260,12 +261,12 @@ func testFrontend(t *testing.T, config Config, handler queryrangebase.Handler, t handlerCfg := transport.HandlerConfig{} flagext.DefaultValues(&handlerCfg) - rt := transport.AdaptGrpcRoundTripperToHTTPRoundTripper(v1) + rt := queryrange.NewSerializeHTTPHandler(transport.AdaptGrpcRoundTripperToHandler(v1, queryrange.DefaultCodec), queryrange.DefaultCodec) r := mux.NewRouter() r.PathPrefix("/").Handler(middleware.Merge( middleware.AuthenticateUser, middleware.Tracer{}, - ).Wrap(transport.NewHandler(handlerCfg, rt, logger, nil))) + ).Wrap(rt)) httpServer := http.Server{ Handler: r, diff --git a/pkg/lokifrontend/frontend/v2/frontend.go b/pkg/lokifrontend/frontend/v2/frontend.go index c085cb86c43d1..0e36d3765ed43 100644 --- a/pkg/lokifrontend/frontend/v2/frontend.go +++ b/pkg/lokifrontend/frontend/v2/frontend.go @@ -14,9 +14,11 @@ import ( "github.com/grafana/dskit/flagext" "github.com/grafana/dskit/grpcclient" "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/dskit/httpgrpc/server" "github.com/grafana/dskit/netutil" "github.com/grafana/dskit/ring" "github.com/grafana/dskit/services" + "github.com/grafana/dskit/user" "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -25,7 +27,9 @@ import ( "github.com/grafana/dskit/tenant" + "github.com/grafana/loki/pkg/lokifrontend/frontend/transport" "github.com/grafana/loki/pkg/lokifrontend/frontend/v2/frontendv2pb" + "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" "github.com/grafana/loki/pkg/querier/stats" lokigrpc "github.com/grafana/loki/pkg/util/httpgrpc" "github.com/grafana/loki/pkg/util/httpreq" @@ -77,8 +81,13 @@ type Frontend struct { schedulerWorkers *frontendSchedulerWorkers requests *requestsInProgress + + codec transport.Codec } +var _ queryrangebase.Handler = &Frontend{} +var _ transport.GrpcRoundTripper = &Frontend{} + type frontendRequest struct { queryID uint64 request *httpgrpc.HTTPRequest @@ -109,7 +118,7 @@ type enqueueResult struct { } // NewFrontend creates a new frontend. -func NewFrontend(cfg Config, ring ring.ReadRing, log log.Logger, reg prometheus.Registerer) (*Frontend, error) { +func NewFrontend(cfg Config, ring ring.ReadRing, log log.Logger, reg prometheus.Registerer, codec transport.Codec) (*Frontend, error) { requestsCh := make(chan *frontendRequest) schedulerWorkers, err := newFrontendSchedulerWorkers(cfg, fmt.Sprintf("%s:%d", cfg.Addr, cfg.Port), ring, requestsCh, log) @@ -123,6 +132,7 @@ func NewFrontend(cfg Config, ring ring.ReadRing, log log.Logger, reg prometheus. 
requestsCh: requestsCh, schedulerWorkers: schedulerWorkers, requests: newRequestsInProgress(), + codec: codec, } // Randomize to avoid getting responses from queries sent before restart, which could lead to mixing results // between different queries. Note that frontend verifies the user, so it cannot leak results between tenants. @@ -219,32 +229,79 @@ func (f *Frontend) RoundTripGRPC(ctx context.Context, req *httpgrpc.HTTPRequest) response: make(chan *frontendv2pb.QueryResultRequest, 1), } - f.requests.put(freq) + cancelCh, err := f.enqueue(ctx, freq) defer f.requests.delete(freq.queryID) + if err != nil { + return nil, err + } - retries := f.cfg.WorkerConcurrency + 1 // To make sure we hit at least two different schedulers. - -enqueueAgain: - var cancelCh chan<- uint64 select { case <-ctx.Done(): + if cancelCh != nil { + select { + case cancelCh <- freq.queryID: + // cancellation sent. + default: + // failed to cancel, ignore. + level.Warn(f.log).Log("msg", "failed to send cancellation request to scheduler, queue full") + } + } return nil, ctx.Err() - case f.requestsCh <- freq: - // Enqueued, let's wait for response. - enqRes := <-freq.enqueue - - if enqRes.status == waitForResponse { - cancelCh = enqRes.cancelCh - break // go wait for response. - } else if enqRes.status == failed { - retries-- - if retries > 0 { - goto enqueueAgain - } + case resp := <-freq.response: + if stats.ShouldTrackHTTPGRPCResponse(resp.HttpResponse) { + stats := stats.FromContext(ctx) + stats.Merge(resp.Stats) // Safe if stats is nil. } - return nil, httpgrpc.Errorf(http.StatusInternalServerError, "failed to enqueue request") + return resp.HttpResponse, nil + } +} + +// Do implements queryrangebase.Handler analogous to RoundTripGRPC. +func (f *Frontend) Do(ctx context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { + tenantIDs, err := tenant.TenantIDs(ctx) + if err != nil { + return nil, err + } + tenantID := tenant.JoinTenantIDs(tenantIDs) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // For backwards compatibility we are sending both encodings + httpReq, err := f.codec.EncodeRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("cannot convert request to HTTP request: %w", err) + } + + if err := user.InjectOrgIDIntoHTTPRequest(ctx, httpReq); err != nil { + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) + } + httpgrpcReq, err := server.HTTPRequest(httpReq) + if err != nil { + return nil, fmt.Errorf("cannot convert HTTP request to gRPC request: %w", err) + } + + freq := &frontendRequest{ + queryID: f.lastQueryID.Inc(), + request: httpgrpcReq, + tenantID: tenantID, + actor: httpreq.ExtractActorPath(ctx), + statsEnabled: stats.IsEnabled(ctx), + + cancel: cancel, + + // Buffer of 1 to ensure response or error can be written to the channel + // even if this goroutine goes away due to client context cancellation. + enqueue: make(chan enqueueResult, 1), + response: make(chan *frontendv2pb.QueryResultRequest, 1), + } + + cancelCh, err := f.enqueue(ctx, freq) + defer f.requests.delete(freq.queryID) + if err != nil { + return nil, err } select { @@ -266,8 +323,39 @@ enqueueAgain: stats.Merge(resp.Stats) // Safe if stats is nil. } - return resp.HttpResponse, nil + return f.codec.DecodeHTTPGrpcResponse(resp.HttpResponse, req) + } +} + +func (f *Frontend) enqueue(ctx context.Context, freq *frontendRequest) (chan<- uint64, error) { + f.requests.put(freq) + + retries := f.cfg.WorkerConcurrency + 1 // To make sure we hit at least two different schedulers.
+ +enqueueAgain: + var cancelCh chan<- uint64 + select { + case <-ctx.Done(): + return cancelCh, ctx.Err() + + case f.requestsCh <- freq: + // Enqueued, let's wait for response. + enqRes := <-freq.enqueue + + if enqRes.status == waitForResponse { + cancelCh = enqRes.cancelCh + break // go wait for response. + } else if enqRes.status == failed { + retries-- + if retries > 0 { + goto enqueueAgain + } + } + + return cancelCh, httpgrpc.Errorf(http.StatusInternalServerError, "failed to enqueue request") } + + return cancelCh, nil } func (f *Frontend) QueryResult(ctx context.Context, qrReq *frontendv2pb.QueryResultRequest) (*frontendv2pb.QueryResultResponse, error) { diff --git a/pkg/lokifrontend/frontend/v2/frontend_test.go b/pkg/lokifrontend/frontend/v2/frontend_test.go index 0c4223747eb09..d70a51852672f 100644 --- a/pkg/lokifrontend/frontend/v2/frontend_test.go +++ b/pkg/lokifrontend/frontend/v2/frontend_test.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc" "github.com/grafana/loki/pkg/lokifrontend/frontend/v2/frontendv2pb" + "github.com/grafana/loki/pkg/querier/queryrange" "github.com/grafana/loki/pkg/querier/stats" "github.com/grafana/loki/pkg/scheduler/schedulerpb" "github.com/grafana/loki/pkg/util/test" @@ -46,7 +47,7 @@ func setupFrontend(t *testing.T, schedulerReplyFunc func(f *Frontend, msg *sched cfg.Port = grpcPort logger := log.NewNopLogger() - f, err := NewFrontend(cfg, nil, logger, nil) + f, err := NewFrontend(cfg, nil, logger, nil, queryrange.DefaultCodec) require.NoError(t, err) frontendv2pb.RegisterFrontendForQuerierServer(server, f) diff --git a/pkg/querier/limits/definitions.go b/pkg/querier/limits/definitions.go new file mode 100644 index 0000000000000..cda30b116976d --- /dev/null +++ b/pkg/querier/limits/definitions.go @@ -0,0 +1,22 @@ +package limits + +import ( + "context" + "time" + + "github.com/grafana/loki/pkg/logql" +) + +type TimeRangeLimits interface { + MaxQueryLookback(context.Context, string) time.Duration + MaxQueryLength(context.Context, string) time.Duration +} + +type Limits interface { + logql.Limits + TimeRangeLimits + QueryTimeout(context.Context, string) time.Duration + MaxStreamsMatchersPerQuery(context.Context, string) int + MaxConcurrentTailRequests(context.Context, string) int + MaxEntriesLimitPerQuery(context.Context, string) int +} diff --git a/pkg/querier/querier.go b/pkg/querier/querier.go index f2333415e6e42..8295f02c644c0 100644 --- a/pkg/querier/querier.go +++ b/pkg/querier/querier.go @@ -27,6 +27,7 @@ import ( "github.com/grafana/loki/pkg/logproto" "github.com/grafana/loki/pkg/logql" "github.com/grafana/loki/pkg/logql/syntax" + querier_limits "github.com/grafana/loki/pkg/querier/limits" "github.com/grafana/loki/pkg/storage" "github.com/grafana/loki/pkg/storage/stores/index/stats" listutil "github.com/grafana/loki/pkg/util" @@ -92,14 +93,7 @@ type Querier interface { Volume(ctx context.Context, req *logproto.VolumeRequest) (*logproto.VolumeResponse, error) } -type Limits interface { - logql.Limits - timeRangeLimits - QueryTimeout(context.Context, string) time.Duration - MaxStreamsMatchersPerQuery(context.Context, string) int - MaxConcurrentTailRequests(context.Context, string) int - MaxEntriesLimitPerQuery(context.Context, string) int -} +type Limits querier_limits.Limits // Store is the store interface we need on the querier.
type Store interface { @@ -667,12 +661,9 @@ func (q *SingleTenantQuerier) validateQueryRequest(ctx context.Context, req logq return validateQueryTimeRangeLimits(ctx, userID, q.limits, req.GetStart(), req.GetEnd()) } -type timeRangeLimits interface { - MaxQueryLookback(context.Context, string) time.Duration - MaxQueryLength(context.Context, string) time.Duration -} +type TimeRangeLimits querier_limits.TimeRangeLimits -func validateQueryTimeRangeLimits(ctx context.Context, userID string, limits timeRangeLimits, from, through time.Time) (time.Time, time.Time, error) { +func validateQueryTimeRangeLimits(ctx context.Context, userID string, limits TimeRangeLimits, from, through time.Time) (time.Time, time.Time, error) { now := nowFunc() // Clamp the time range based on the max query lookback. maxQueryLookback := limits.MaxQueryLookback(ctx, userID) diff --git a/pkg/querier/querier_test.go b/pkg/querier/querier_test.go index 356ff7fd65469..d89d24a1751b0 100644 --- a/pkg/querier/querier_test.go +++ b/pkg/querier/querier_test.go @@ -1136,7 +1136,7 @@ func Test_validateQueryTimeRangeLimits(t *testing.T) { nowFunc = func() time.Time { return now } tests := []struct { name string - limits timeRangeLimits + limits TimeRangeLimits from time.Time through time.Time wantFrom time.Time diff --git a/pkg/querier/queryrange/codec.go b/pkg/querier/queryrange/codec.go index fdca5a6a5c9b6..5b9611a4f38c5 100644 --- a/pkg/querier/queryrange/codec.go +++ b/pkg/querier/queryrange/codec.go @@ -35,6 +35,7 @@ import ( "github.com/grafana/loki/pkg/util/httpreq" "github.com/grafana/loki/pkg/util/marshal" marshal_legacy "github.com/grafana/loki/pkg/util/marshal/legacy" + "github.com/grafana/loki/pkg/util/querylimits" ) var DefaultCodec = &Codec{} @@ -299,6 +300,7 @@ func (Codec) DecodeRequest(_ context.Context, r *http.Request, _ []string) (quer if err != nil { return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } + return &LabelRequest{ LabelRequest: *req, path: r.URL.Path, @@ -352,7 +354,7 @@ func (Codec) DecodeRequest(_ context.Context, r *http.Request, _ []string) (quer // labelNamesRoutes is used to extract the name for querying label values. var labelNamesRoutes = regexp.MustCompile(`/loki/api/v1/label/(?P[^/]+)/values`) -// DecodeHTTPGrpcRequest decodes an httpgrp.HTTPrequest to queryrangebase.Request. +// DecodeHTTPGrpcRequest decodes an httpgrp.HTTPRequest to queryrangebase.Request. func (Codec) DecodeHTTPGrpcRequest(ctx context.Context, r *httpgrpc.HTTPRequest) (queryrangebase.Request, context.Context, error) { httpReq, err := http.NewRequest(r.Method, r.Url, io.NopCloser(bytes.NewBuffer(r.Body))) if err != nil { @@ -485,6 +487,15 @@ func (Codec) DecodeHTTPGrpcRequest(ctx context.Context, r *httpgrpc.HTTPRequest) } } +// DecodeHTTPGrpcResponse decodes an httpgrp.HTTPResponse to queryrangebase.Response. 
+func (Codec) DecodeHTTPGrpcResponse(r *httpgrpc.HTTPResponse, req queryrangebase.Request) (queryrangebase.Response, error) { + headers := make(http.Header) + for _, header := range r.Headers { + headers[header.Key] = header.Values + } + return decodeResponseJSONFrom(r.Body, req, headers) +} + func (Codec) EncodeHTTPGrpcResponse(ctx context.Context, req *httpgrpc.HTTPRequest, res queryrangebase.Response) (*httpgrpc.HTTPResponse, error) { version := loghttp.GetVersion(req.Url) var buf bytes.Buffer @@ -515,6 +526,14 @@ func (c Codec) EncodeRequest(ctx context.Context, r queryrangebase.Request) (*ht header.Set(httpreq.LokiActorPathHeader, actor) } + limits := querylimits.ExtractQueryLimitsContext(ctx) + if limits != nil { + err := querylimits.InjectQueryLimitsHeader(&header, limits) + if err != nil { + return nil, err + } + } + switch request := r.(type) { case *LokiRequest: params := url.Values{ @@ -697,7 +716,6 @@ func (Codec) DecodeResponse(_ context.Context, r *http.Response, req queryrangeb } func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrangebase.Response, error) { - var buf []byte var err error if buffer, ok := r.Body.(Buffer); ok { @@ -709,6 +727,11 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang } } + return decodeResponseJSONFrom(buf, req, r.Header) +} + +func decodeResponseJSONFrom(buf []byte, req queryrangebase.Request, headers http.Header) (queryrangebase.Response, error) { + switch req := req.(type) { case *LokiSeriesRequest: var resp loghttp.SeriesResponse @@ -728,7 +751,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang Status: resp.Status, Version: uint32(loghttp.GetVersion(req.Path)), Data: data, - Headers: httpResponseHeadersToPromResponseHeaders(r.Header), + Headers: httpResponseHeadersToPromResponseHeaders(headers), }, nil case *LabelRequest: var resp loghttp.LabelResponse @@ -739,7 +762,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang Status: resp.Status, Version: uint32(loghttp.GetVersion(req.Path())), Data: resp.Data, - Headers: httpResponseHeadersToPromResponseHeaders(r.Header), + Headers: httpResponseHeadersToPromResponseHeaders(headers), }, nil case *logproto.IndexStatsRequest: var resp logproto.IndexStatsResponse @@ -748,7 +771,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang } return &IndexStatsResponse{ Response: &resp, - Headers: httpResponseHeadersToPromResponseHeaders(r.Header), + Headers: httpResponseHeadersToPromResponseHeaders(headers), }, nil case *logproto.VolumeRequest: var resp logproto.VolumeResponse @@ -757,7 +780,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang } return &VolumeResponse{ Response: &resp, - Headers: httpResponseHeadersToPromResponseHeaders(r.Header), + Headers: httpResponseHeadersToPromResponseHeaders(headers), }, nil default: var resp loghttp.QueryResponse @@ -773,7 +796,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang ResultType: loghttp.ResultTypeMatrix, Result: toProtoMatrix(resp.Data.Result.(loghttp.Matrix)), }, - Headers: convertPrometheusResponseHeadersToPointers(httpResponseHeadersToPromResponseHeaders(r.Header)), + Headers: convertPrometheusResponseHeadersToPointers(httpResponseHeadersToPromResponseHeaders(headers)), }, Statistics: resp.Data.Statistics, }, nil @@ -803,7 +826,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang ResultType: 
loghttp.ResultTypeStream, Result: resp.Data.Result.(loghttp.Streams).ToProto(), }, - Headers: httpResponseHeadersToPromResponseHeaders(r.Header), + Headers: httpResponseHeadersToPromResponseHeaders(headers), }, nil case loghttp.ResultTypeVector: return &LokiPromResponse{ @@ -813,7 +836,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang ResultType: loghttp.ResultTypeVector, Result: toProtoVector(resp.Data.Result.(loghttp.Vector)), }, - Headers: convertPrometheusResponseHeadersToPointers(httpResponseHeadersToPromResponseHeaders(r.Header)), + Headers: convertPrometheusResponseHeadersToPointers(httpResponseHeadersToPromResponseHeaders(headers)), }, Statistics: resp.Data.Statistics, }, nil @@ -825,7 +848,7 @@ func decodeResponseJSON(r *http.Response, req queryrangebase.Request) (queryrang ResultType: loghttp.ResultTypeScalar, Result: toProtoScalar(resp.Data.Result.(loghttp.Scalar)), }, - Headers: convertPrometheusResponseHeadersToPointers(httpResponseHeadersToPromResponseHeaders(r.Header)), + Headers: convertPrometheusResponseHeadersToPointers(httpResponseHeadersToPromResponseHeaders(headers)), }, Statistics: resp.Data.Statistics, }, nil @@ -877,7 +900,7 @@ func decodeResponseProtobuf(r *http.Response, req queryrangebase.Request) (query case *QueryResponse_QuantileSketches: return concrete.QuantileSketches.WithHeaders(headers), nil default: - return nil, httpgrpc.Errorf(http.StatusInternalServerError, "unsupported response type, got (%t)", resp.Response) + return nil, httpgrpc.Errorf(http.StatusInternalServerError, "unsupported response type, got (%T)", resp.Response) } } } @@ -977,31 +1000,9 @@ func encodeResponseProtobuf(ctx context.Context, res queryrangebase.Response) (* sp, _ := opentracing.StartSpanFromContext(ctx, "codec.EncodeResponse") defer sp.Finish() - p := QueryResponse{} - - switch response := res.(type) { - case *LokiPromResponse: - p.Response = &QueryResponse_Prom{response} - case *LokiResponse: - p.Response = &QueryResponse_Streams{response} - case *LokiSeriesResponse: - p.Response = &QueryResponse_Series{response} - case *MergedSeriesResponseView: - mat, err := response.Materialize() - if err != nil { - return nil, err - } - p.Response = &QueryResponse_Series{mat} - case *LokiLabelNamesResponse: - p.Response = &QueryResponse_Labels{response} - case *IndexStatsResponse: - p.Response = &QueryResponse_Stats{response} - case *TopKSketchesResponse: - p.Response = &QueryResponse_TopkSketches{response} - case *QuantileSketchResponse: - p.Response = &QueryResponse_QuantileSketches{response} - default: - return nil, httpgrpc.Errorf(http.StatusInternalServerError, fmt.Sprintf("invalid response format, got (%T)", res)) + p, err := QueryResponseWrap(res) + if err != nil { + return nil, httpgrpc.Errorf(http.StatusInternalServerError, err.Error()) } buf, err := p.Marshal() diff --git a/pkg/querier/queryrange/codec_test.go b/pkg/querier/queryrange/codec_test.go index 9be6a1c856af2..2131d34eab67a 100644 --- a/pkg/querier/queryrange/codec_test.go +++ b/pkg/querier/queryrange/codec_test.go @@ -72,6 +72,19 @@ func Test_codec_EncodeDecodeRequest(t *testing.T) { StartTs: start, EndTs: end, }, false}, + {"legacy query_range with refexp", func() (*http.Request, error) { + return http.NewRequest(http.MethodGet, + fmt.Sprintf(`/api/prom/query?start=%d&end=%d&query={foo="bar"}&interval=10&limit=200&direction=BACKWARD&regexp=foo`, start.UnixNano(), end.UnixNano()), nil) + }, &LokiRequest{ + Query: `{foo="bar"} |~ "foo"`, + Limit: 200, + Step: 14000, // step is expected in
ms; calculated default if request param not present + Interval: 10000, // interval is expected in ms + Direction: logproto.BACKWARD, + Path: "/api/prom/query", + StartTs: start, + EndTs: end, + }, false}, {"series", func() (*http.Request, error) { return http.NewRequest(http.MethodGet, fmt.Sprintf(`/series?start=%d&end=%d&match={foo="bar"}`, start.UnixNano(), end.UnixNano()), nil) diff --git a/pkg/querier/queryrange/downstreamer.go b/pkg/querier/queryrange/downstreamer.go index a98998f4ee793..b7a3d2f57a3ff 100644 --- a/pkg/querier/queryrange/downstreamer.go +++ b/pkg/querier/queryrange/downstreamer.go @@ -15,10 +15,8 @@ import ( "github.com/prometheus/prometheus/promql" "github.com/prometheus/prometheus/promql/parser" - "github.com/grafana/loki/pkg/loghttp" "github.com/grafana/loki/pkg/logproto" "github.com/grafana/loki/pkg/logql" - "github.com/grafana/loki/pkg/logql/sketch" "github.com/grafana/loki/pkg/logqlmodel" "github.com/grafana/loki/pkg/logqlmodel/metadata" "github.com/grafana/loki/pkg/logqlmodel/stats" @@ -217,65 +215,6 @@ func sampleStreamToVector(streams []queryrangebase.SampleStream) parser.Value { return xs } -func ResponseToResult(resp queryrangebase.Response) (logqlmodel.Result, error) { - switch r := resp.(type) { - case *LokiResponse: - if r.Error != "" { - return logqlmodel.Result{}, fmt.Errorf("%s: %s", r.ErrorType, r.Error) - } - - streams := make(logqlmodel.Streams, 0, len(r.Data.Result)) - - for _, stream := range r.Data.Result { - streams = append(streams, stream) - } - - return logqlmodel.Result{ - Statistics: r.Statistics, - Data: streams, - Headers: resp.GetHeaders(), - }, nil - - case *LokiPromResponse: - if r.Response.Error != "" { - return logqlmodel.Result{}, fmt.Errorf("%s: %s", r.Response.ErrorType, r.Response.Error) - } - if r.Response.Data.ResultType == loghttp.ResultTypeVector { - return logqlmodel.Result{ - Statistics: r.Statistics, - Data: sampleStreamToVector(r.Response.Data.Result), - Headers: resp.GetHeaders(), - }, nil - } - return logqlmodel.Result{ - Statistics: r.Statistics, - Data: sampleStreamToMatrix(r.Response.Data.Result), - Headers: resp.GetHeaders(), - }, nil - case *TopKSketchesResponse: - matrix, err := sketch.TopKMatrixFromProto(r.Response) - if err != nil { - return logqlmodel.Result{}, fmt.Errorf("cannot decode topk sketch: %w", err) - } - - return logqlmodel.Result{ - Data: matrix, - Headers: resp.GetHeaders(), - }, nil - case *QuantileSketchResponse: - matrix, err := sketch.QuantileSketchMatrixFromProto(r.Response) - if err != nil { - return logqlmodel.Result{}, fmt.Errorf("cannot decode quantile sketch: %w", err) - } - return logqlmodel.Result{ - Data: matrix, - Headers: resp.GetHeaders(), - }, nil - default: - return logqlmodel.Result{}, fmt.Errorf("cannot decode (%T)", resp) - } -} - // downstreamAccumulator is one of two variants: // a logsAccumulator or a bufferedAccumulator. // Which variant is detected on the first call to Accumulate. 
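Reviewer note: the hunks above all follow from one shape change. Query-frontend plumbing that used to be expressed as a Tripperware over `http.RoundTripper` is now expressed as a `queryrangebase.Middleware` over `queryrangebase.Handler`, so requests and responses stay in their decoded form between layers and are only (de)serialized at the edges by a codec. A minimal sketch of the two shapes, using simplified stand-in types rather than the real Loki interfaces; the logging middleware is a hypothetical example, not code from this PR:

```go
// Sketch of the Tripperware -> Middleware shape change, with simplified
// stand-in types. Not the actual Loki definitions.
package example

import (
	"context"
	"log"
	"net/http"
)

// Before: a Tripperware decorates an http.RoundTripper, so every layer
// works on encoded *http.Request / *http.Response values.
type Tripperware func(http.RoundTripper) http.RoundTripper

// After: a Middleware decorates a Handler that works directly on decoded
// request/response objects (signatures simplified here).
type (
	Request  interface{}
	Response interface{}

	Handler interface {
		Do(ctx context.Context, req Request) (Response, error)
	}

	HandlerFunc func(ctx context.Context, req Request) (Response, error)

	Middleware interface {
		Wrap(Handler) Handler
	}

	MiddlewareFunc func(Handler) Handler
)

func (f HandlerFunc) Do(ctx context.Context, req Request) (Response, error) { return f(ctx, req) }

func (f MiddlewareFunc) Wrap(h Handler) Handler { return f(h) }

// A toy middleware in the new style: it can inspect the decoded request
// without paying a codec round trip at this layer.
var logRequests Middleware = MiddlewareFunc(func(next Handler) Handler {
	return HandlerFunc(func(ctx context.Context, req Request) (Response, error) {
		log.Printf("frontend handling %T", req)
		return next.Do(ctx, req)
	})
})
```

In the new style, adding behavior is `middleware.Wrap(handler)` rather than wrapping a RoundTripper, which is why the removals below drop the codec plumbing from the per-layer round trippers.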
diff --git a/pkg/querier/queryrange/limits.go b/pkg/querier/queryrange/limits.go index 99967b3e24513..ddf38d30cd004 100644 --- a/pkg/querier/queryrange/limits.go +++ b/pkg/querier/queryrange/limits.go @@ -14,7 +14,7 @@ import ( "github.com/go-kit/log/level" "github.com/grafana/dskit/httpgrpc" "github.com/grafana/dskit/tenant" - "github.com/grafana/dskit/user" + "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/prometheus/common/model" @@ -25,6 +25,7 @@ import ( "github.com/grafana/loki/pkg/logproto" "github.com/grafana/loki/pkg/logql" "github.com/grafana/loki/pkg/logql/syntax" + queryrange_limits "github.com/grafana/loki/pkg/querier/queryrange/limits" "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" "github.com/grafana/loki/pkg/storage/config" "github.com/grafana/loki/pkg/storage/stores/index/stats" @@ -49,27 +50,7 @@ var ( ErrMaxQueryParalellism = fmt.Errorf("querying is disabled, please contact your Loki operator") ) -// Limits extends the cortex limits interface with support for per tenant splitby parameters -type Limits interface { - queryrangebase.Limits - logql.Limits - QuerySplitDuration(string) time.Duration - MaxQuerySeries(context.Context, string) int - MaxEntriesLimitPerQuery(context.Context, string) int - MinShardingLookback(string) time.Duration - // TSDBMaxQueryParallelism returns the limit to the number of split queries the - // frontend will process in parallel for TSDB queries. - TSDBMaxQueryParallelism(context.Context, string) int - // TSDBMaxBytesPerShard returns the limit to the number of bytes a single shard - TSDBMaxBytesPerShard(string) int - - RequiredLabels(context.Context, string) []string - RequiredNumberLabels(context.Context, string) int - MaxQueryBytesRead(context.Context, string) int - MaxQuerierBytesRead(context.Context, string) int - MaxStatsCacheFreshness(context.Context, string) time.Duration - VolumeEnabled(string) bool -} +type Limits queryrange_limits.Limits type limits struct { Limits @@ -453,39 +434,33 @@ func (sl *seriesLimiter) isLimitReached() bool { type limitedRoundTripper struct { configs []config.PeriodConfig - next http.RoundTripper + next queryrangebase.Handler limits Limits - codec queryrangebase.Codec middleware queryrangebase.Middleware } +var _ queryrangebase.Handler = limitedRoundTripper{} + // NewLimitedRoundTripper creates a new roundtripper that enforces MaxQueryParallelism to the `next` roundtripper across `middlewares`. -func NewLimitedRoundTripper(next http.RoundTripper, codec queryrangebase.Codec, limits Limits, configs []config.PeriodConfig, middlewares ...queryrangebase.Middleware) http.RoundTripper { +func NewLimitedRoundTripper(next queryrangebase.Handler, limits Limits, configs []config.PeriodConfig, middlewares ...queryrangebase.Middleware) queryrangebase.Handler { transport := limitedRoundTripper{ configs: configs, next: next, - codec: codec, limits: limits, middleware: queryrangebase.MergeMiddlewares(middlewares...), } return transport } -func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { +func (rt limitedRoundTripper) Do(c context.Context, request queryrangebase.Request) (queryrangebase.Response, error) { var ( - ctx, cancel = context.WithCancel(r.Context()) + ctx, cancel = context.WithCancel(c) ) defer func() { cancel() }() - // Do not forward any request header. 
- request, err := rt.codec.DecodeRequest(ctx, r, nil) - if err != nil { - return nil, err - } - if span := opentracing.SpanFromContext(ctx); span != nil { request.LogToSpan(span) } @@ -509,7 +484,7 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) sem := semaphore.NewWeighted(int64(parallelism)) - response, err := rt.middleware.Wrap( + return rt.middleware.Wrap( queryrangebase.HandlerFunc(func(ctx context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { // This inner handler is called multiple times by // sharding outer middlewares such as the downstreamer. @@ -523,35 +498,8 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) } defer sem.Release(int64(1)) - return rt.do(ctx, r) + return rt.next.Do(ctx, r) })).Do(ctx, request) - if err != nil { - return nil, err - } - - return rt.codec.EncodeResponse(ctx, r, response) -} - -func (rt limitedRoundTripper) do(ctx context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { - sp, ctx := opentracing.StartSpanFromContext(ctx, "limitedRoundTripper.do") - defer sp.Finish() - - request, err := rt.codec.EncodeRequest(ctx, r) - if err != nil { - return nil, err - } - - if err := user.InjectOrgIDIntoHTTPRequest(ctx, request); err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - - response, err := rt.next.RoundTrip(request) - if err != nil { - return nil, err - } - defer func() { _ = response.Body.Close() }() - - return rt.codec.DecodeResponse(ctx, response, r) } // WeightedParallelism will calculate the request parallelism to use @@ -688,13 +636,13 @@ func MinWeightedParallelism(ctx context.Context, tenantIDs []string, configs []c } // validates log entries limits -func validateMaxEntriesLimits(req *http.Request, reqLimit uint32, limits Limits) error { - tenantIDs, err := tenant.TenantIDs(req.Context()) +func validateMaxEntriesLimits(ctx context.Context, reqLimit uint32, limits Limits) error { + tenantIDs, err := tenant.TenantIDs(ctx) if err != nil { return httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } - maxEntriesCapture := func(id string) int { return limits.MaxEntriesLimitPerQuery(req.Context(), id) } + maxEntriesCapture := func(id string) int { return limits.MaxEntriesLimitPerQuery(ctx, id) } maxEntriesLimit := validation.SmallestPositiveNonZeroIntPerTenant(tenantIDs, maxEntriesCapture) if int(reqLimit) > maxEntriesLimit && maxEntriesLimit != 0 { @@ -703,8 +651,8 @@ func validateMaxEntriesLimits(req *http.Request, reqLimit uint32, limits Limits) return nil } -func validateMatchers(req *http.Request, limits Limits, matchers []*labels.Matcher) error { - tenants, err := tenant.TenantIDs(req.Context()) +func validateMatchers(ctx context.Context, limits Limits, matchers []*labels.Matcher) error { + tenants, err := tenant.TenantIDs(ctx) if err != nil { return err } @@ -718,7 +666,7 @@ func validateMatchers(req *http.Request, limits Limits, matchers []*labels.Match // Enforce RequiredLabels limit for _, tenant := range tenants { - required := limits.RequiredLabels(req.Context(), tenant) + required := limits.RequiredLabels(ctx, tenant) var missing []string for _, label := range required { if _, found := actual[label]; !found { @@ -735,7 +683,7 @@ func validateMatchers(req *http.Request, limits Limits, matchers []*labels.Match // The reason to enforce this one after RequiredLabels is to avoid users // from adding enough label matchers to pass the RequiredNumberLabels limit but then // having to modify them to use 
the ones required by RequiredLabels. - requiredNumberLabelsCapture := func(id string) int { return limits.RequiredNumberLabels(req.Context(), id) } + requiredNumberLabelsCapture := func(id string) int { return limits.RequiredNumberLabels(ctx, id) } if requiredNumberLabels := validation.SmallestPositiveNonZeroIntPerTenant(tenants, requiredNumberLabelsCapture); requiredNumberLabels > 0 { if len(present) < requiredNumberLabels { return fmt.Errorf(requiredNumberLabelsErrTmpl, strings.Join(present, ", "), len(present), requiredNumberLabels) diff --git a/pkg/querier/queryrange/limits/defitions.go b/pkg/querier/queryrange/limits/defitions.go new file mode 100644 index 0000000000000..bc8f7d0ec94bd --- /dev/null +++ b/pkg/querier/queryrange/limits/defitions.go @@ -0,0 +1,32 @@ +package limits + +import ( + "context" + "time" + + "github.com/grafana/loki/pkg/logql" + "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" +) + +// Limits extends the cortex limits interface with support for per tenant splitby parameters +// They've been extracted to avoid import cycles. +type Limits interface { + queryrangebase.Limits + logql.Limits + QuerySplitDuration(string) time.Duration + MaxQuerySeries(context.Context, string) int + MaxEntriesLimitPerQuery(context.Context, string) int + MinShardingLookback(string) time.Duration + // TSDBMaxQueryParallelism returns the limit to the number of split queries the + // frontend will process in parallel for TSDB queries. + TSDBMaxQueryParallelism(context.Context, string) int + // TSDBMaxBytesPerShard returns the limit to the number of bytes a single shard + TSDBMaxBytesPerShard(string) int + + RequiredLabels(context.Context, string) []string + RequiredNumberLabels(context.Context, string) int + MaxQueryBytesRead(context.Context, string) int + MaxQuerierBytesRead(context.Context, string) int + MaxStatsCacheFreshness(context.Context, string) time.Duration + VolumeEnabled(string) bool +} diff --git a/pkg/querier/queryrange/limits_test.go b/pkg/querier/queryrange/limits_test.go index cca9946f0c161..02c3862dd45a6 100644 --- a/pkg/querier/queryrange/limits_test.go +++ b/pkg/querier/queryrange/limits_test.go @@ -3,7 +3,6 @@ package queryrange import ( "context" "fmt" - "net/http" "sync" "testing" "time" @@ -17,11 +16,10 @@ import ( "gopkg.in/yaml.v2" "github.com/grafana/loki/pkg/logproto" - "github.com/grafana/loki/pkg/logqlmodel/stats" - "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" + "github.com/grafana/loki/pkg/logqlmodel" + base "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" "github.com/grafana/loki/pkg/storage/config" util_log "github.com/grafana/loki/pkg/util/log" - "github.com/grafana/loki/pkg/util/marshal" "github.com/grafana/loki/pkg/util/math" ) @@ -56,7 +54,7 @@ func Test_seriesLimiter(t *testing.T) { cfg.CacheIndexStatsResults = false // split in 7 with 2 in // max. 
l := WithSplitByLimits(fakeLimits{maxSeries: 1, maxQueryParallelism: 2}, time.Hour) - tpw, stopper, err := NewTripperware(cfg, testEngineOpts, util_log.Logger, l, config.SchemaConfig{ + tpw, stopper, err := NewMiddleware(cfg, testEngineOpts, util_log.Logger, l, config.SchemaConfig{ Configs: testSchemas, }, nil, false, nil) if stopper != nil { @@ -75,28 +73,16 @@ func Test_seriesLimiter(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() count, h := promqlResult(matrix) - rt.setHandler(h) - - _, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.NoError(t, err) require.Equal(t, 7, *count) // 2 series should not be allowed. c := new(int) m := &sync.Mutex{} - h = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + h = base.HandlerFunc(func(_ context.Context, req base.Request) (base.Response, error) { m.Lock() defer m.Unlock() defer func() { @@ -104,52 +90,51 @@ func Test_seriesLimiter(t *testing.T) { }() // first time returns a single series if *c == 0 { - if err := marshal.WriteQueryResponseJSON(matrix, stats.Result{}, rw); err != nil { - panic(err) + params, err := ParamsFromRequest(req) + if err != nil { + return nil, err } - return + return ResultToResponse(logqlmodel.Result{Data: matrix}, params) } // second time returns a different series. - if err := marshal.WriteQueryResponseJSON( - promql.Matrix{ - { - Floats: []promql.FPoint{ - { - T: toMs(testTime.Add(-4 * time.Hour)), - F: 0.013333333333333334, - }, + m := promql.Matrix{ + { + Floats: []promql.FPoint{ + { + T: toMs(testTime.Add(-4 * time.Hour)), + F: 0.013333333333333334, }, - Metric: []labels.Label{ - { - Name: "filename", - Value: `/var/hostlog/apport.log`, - }, - { - Name: "job", - Value: "anotherjob", - }, + }, + Metric: []labels.Label{ + { + Name: "filename", + Value: `/var/hostlog/apport.log`, + }, + { + Name: "job", + Value: "anotherjob", }, }, }, - stats.Result{}, - rw); err != nil { - panic(err) } + params, err := ParamsFromRequest(req) + if err != nil { + return nil, err + } + return ResultToResponse(logqlmodel.Result{Data: m}, params) }) - rt.setHandler(h) - _, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Error(t, err) require.LessOrEqual(t, *c, 4) } func Test_MaxQueryParallelism(t *testing.T) { maxQueryParallelism := 2 - f, err := newfakeRoundTripper() - require.Nil(t, err) + var count atomic.Int32 var max atomic.Int32 - f.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + h := base.HandlerFunc(func(_ context.Context, _ base.Request) (base.Response, error) { cur := count.Inc() if cur > max.Load() { max.Store(cur) @@ -157,16 +142,14 @@ func Test_MaxQueryParallelism(t *testing.T) { defer count.Dec() // simulate some work time.Sleep(20 * time.Millisecond) - })) + return base.NewEmptyPrometheusResponse(), nil + }) ctx := user.InjectOrgID(context.Background(), "foo") - r, err := http.NewRequestWithContext(ctx, "GET", "/query_range", http.NoBody) - require.Nil(t, err) - - _, _ = NewLimitedRoundTripper(f, DefaultCodec, fakeLimits{maxQueryParallelism: maxQueryParallelism}, + _, _ = NewLimitedRoundTripper(h, fakeLimits{maxQueryParallelism: maxQueryParallelism}, testSchemas, - queryrangebase.MiddlewareFunc(func(next queryrangebase.Handler) 
queryrangebase.Handler { - return queryrangebase.HandlerFunc(func(c context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { + base.MiddlewareFunc(func(next base.Handler) base.Handler { + return base.HandlerFunc(func(c context.Context, r base.Request) (base.Response, error) { var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) @@ -179,58 +162,52 @@ func Test_MaxQueryParallelism(t *testing.T) { return nil, nil }) }), - ).RoundTrip(r) + ).Do(ctx, &LokiRequest{}) maxFound := int(max.Load()) require.LessOrEqual(t, maxFound, maxQueryParallelism, "max query parallelism: ", maxFound, " went over the configured one:", maxQueryParallelism) } func Test_MaxQueryParallelismLateScheduling(t *testing.T) { maxQueryParallelism := 2 - f, err := newfakeRoundTripper() - require.Nil(t, err) - f.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + h := base.HandlerFunc(func(_ context.Context, _ base.Request) (base.Response, error) { // simulate some work time.Sleep(20 * time.Millisecond) - })) + return base.NewEmptyPrometheusResponse(), nil + }) ctx := user.InjectOrgID(context.Background(), "foo") - r, err := http.NewRequestWithContext(ctx, "GET", "/query_range", http.NoBody) - require.Nil(t, err) - - _, _ = NewLimitedRoundTripper(f, DefaultCodec, fakeLimits{maxQueryParallelism: maxQueryParallelism}, + _, err := NewLimitedRoundTripper(h, fakeLimits{maxQueryParallelism: maxQueryParallelism}, testSchemas, - queryrangebase.MiddlewareFunc(func(next queryrangebase.Handler) queryrangebase.Handler { - return queryrangebase.HandlerFunc(func(c context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { + base.MiddlewareFunc(func(next base.Handler) base.Handler { + return base.HandlerFunc(func(c context.Context, r base.Request) (base.Response, error) { for i := 0; i < 10; i++ { go func() { - _, _ = next.Do(c, &LokiRequest{}) + _, _ = next.Do(c, r) }() } return nil, nil }) }), - ).RoundTrip(r) + ).Do(ctx, &LokiRequest{}) + + require.NoError(t, err) } func Test_MaxQueryParallelismDisable(t *testing.T) { maxQueryParallelism := 0 - f, err := newfakeRoundTripper() - require.Nil(t, err) - f.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + h := base.HandlerFunc(func(_ context.Context, _ base.Request) (base.Response, error) { // simulate some work time.Sleep(20 * time.Millisecond) - })) + return base.NewEmptyPrometheusResponse(), nil + }) ctx := user.InjectOrgID(context.Background(), "foo") - r, err := http.NewRequestWithContext(ctx, "GET", "/query_range", http.NoBody) - require.Nil(t, err) - - _, err = NewLimitedRoundTripper(f, DefaultCodec, fakeLimits{maxQueryParallelism: maxQueryParallelism}, + _, err := NewLimitedRoundTripper(h, fakeLimits{maxQueryParallelism: maxQueryParallelism}, testSchemas, - queryrangebase.MiddlewareFunc(func(next queryrangebase.Handler) queryrangebase.Handler { - return queryrangebase.HandlerFunc(func(c context.Context, r queryrangebase.Request) (queryrangebase.Response, error) { + base.MiddlewareFunc(func(next base.Handler) base.Handler { + return base.HandlerFunc(func(c context.Context, r base.Request) (base.Response, error) { for i := 0; i < 10; i++ { go func() { _, _ = next.Do(c, &LokiRequest{}) @@ -239,12 +216,12 @@ func Test_MaxQueryParallelismDisable(t *testing.T) { return nil, nil }) }), - ).RoundTrip(r) + ).Do(ctx, &LokiRequest{}) require.Error(t, err) } func Test_MaxQueryLookBack(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{ + 
tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{ maxQueryLookback: 1 * time.Hour, maxQueryParallelism: 1, }, config.SchemaConfig{ @@ -254,9 +231,6 @@ func Test_MaxQueryLookBack(t *testing.T) { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := &LokiRequest{ Query: `{app="foo"} |= "foo"`, @@ -268,15 +242,17 @@ func Test_MaxQueryLookBack(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) + called := false + h := base.HandlerFunc(func(context.Context, base.Request) (base.Response, error) { + called = true + return nil, nil + }) - _, err = tpw(rt).RoundTrip(req) + resp, err := tpw.Wrap(h).Do(ctx, lreq) require.NoError(t, err) + require.False(t, called) + require.Equal(t, resp.(*LokiResponse).Status, "success") } func Test_GenerateCacheKey_NoDivideZero(t *testing.T) { @@ -436,19 +412,6 @@ func Test_WeightedParallelism_DivideByZeroError(t *testing.T) { }) } -func getFakeStatsHandler(retBytes uint64) (queryrangebase.Handler, *int, error) { - fakeRT, err := newfakeRoundTripper() - if err != nil { - return nil, nil, err - } - - count, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: retBytes}) - - fakeRT.setHandler(statsHandler) - - return queryrangebase.NewRoundTripperHandler(fakeRT, DefaultCodec), count, nil -} - func Test_MaxQuerySize(t *testing.T) { const statsBytes = 1000 @@ -569,17 +532,11 @@ func Test_MaxQuerySize(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - queryStatsHandler, queryStatsHits, err := getFakeStatsHandler(uint64(statsBytes / math.Max(tc.expectedQueryStatsHits, 1))) - require.NoError(t, err) + queryStatsHits, queryStatsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: uint64(statsBytes / math.Max(tc.expectedQueryStatsHits, 1))}) - querierStatsHandler, querierStatsHits, err := getFakeStatsHandler(uint64(statsBytes / math.Max(tc.expectedQuerierStatsHits, 1))) - require.NoError(t, err) - - fakeRT, err := newfakeRoundTripper() - require.NoError(t, err) + querierStatsHits, querierStatsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: uint64(statsBytes / math.Max(tc.expectedQuerierStatsHits, 1))}) _, promHandler := promqlResult(matrix) - fakeRT.setHandler(promHandler) lokiReq := &LokiRequest{ Query: tc.query, @@ -591,19 +548,13 @@ func Test_MaxQuerySize(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "foo") - req, err := DefaultCodec.EncodeRequest(ctx, lokiReq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - middlewares := []queryrangebase.Middleware{ + middlewares := []base.Middleware{ NewQuerySizeLimiterMiddleware(schemas, testEngineOpts, util_log.Logger, tc.limits, queryStatsHandler), NewQuerierSizeLimiterMiddleware(schemas, testEngineOpts, util_log.Logger, tc.limits, querierStatsHandler), } - _, err = queryrangebase.NewRoundTripper(fakeRT, DefaultCodec, nil, middlewares...).RoundTrip(req) + _, err := base.MergeMiddlewares(middlewares...).Wrap(promHandler).Do(ctx, lokiReq) if tc.shouldErr { require.Error(t, err) @@ -627,7 +578,7 @@ func Test_MaxQuerySize_MaxLookBackPeriod(t *testing.T) { maxQuerierBytesRead: 1 << 10, } - statsHandler := queryrangebase.HandlerFunc(func(_ context.Context, req 
queryrangebase.Request) (queryrangebase.Response, error) { + statsHandler := base.HandlerFunc(func(_ context.Context, req base.Request) (base.Response, error) { // This is the actual check that we're testing. require.Equal(t, testTime.Add(-engineOpts.MaxLookBackPeriod).UnixMilli(), req.GetStart()) @@ -640,7 +591,7 @@ func Test_MaxQuerySize_MaxLookBackPeriod(t *testing.T) { for _, tc := range []struct { desc string - middleware queryrangebase.Middleware + middleware base.Middleware }{ { desc: "QuerySizeLimiter", @@ -661,7 +612,7 @@ func Test_MaxQuerySize_MaxLookBackPeriod(t *testing.T) { } handler := tc.middleware.Wrap( - queryrangebase.HandlerFunc(func(_ context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { + base.HandlerFunc(func(_ context.Context, req base.Request) (base.Response, error) { return &LokiResponse{}, nil }), ) diff --git a/pkg/querier/queryrange/marshal.go b/pkg/querier/queryrange/marshal.go index 512cfb321b8f4..b177f5cf86324 100644 --- a/pkg/querier/queryrange/marshal.go +++ b/pkg/querier/queryrange/marshal.go @@ -112,6 +112,65 @@ func ResultToResponse(result logqlmodel.Result, params logql.Params) (queryrange return nil, fmt.Errorf("unsupported data type: %t", result.Data) } +func ResponseToResult(resp queryrangebase.Response) (logqlmodel.Result, error) { + switch r := resp.(type) { + case *LokiResponse: + if r.Error != "" { + return logqlmodel.Result{}, fmt.Errorf("%s: %s", r.ErrorType, r.Error) + } + + streams := make(logqlmodel.Streams, 0, len(r.Data.Result)) + + for _, stream := range r.Data.Result { + streams = append(streams, stream) + } + + return logqlmodel.Result{ + Statistics: r.Statistics, + Data: streams, + Headers: resp.GetHeaders(), + }, nil + + case *LokiPromResponse: + if r.Response.Error != "" { + return logqlmodel.Result{}, fmt.Errorf("%s: %s", r.Response.ErrorType, r.Response.Error) + } + if r.Response.Data.ResultType == loghttp.ResultTypeVector { + return logqlmodel.Result{ + Statistics: r.Statistics, + Data: sampleStreamToVector(r.Response.Data.Result), + Headers: resp.GetHeaders(), + }, nil + } + return logqlmodel.Result{ + Statistics: r.Statistics, + Data: sampleStreamToMatrix(r.Response.Data.Result), + Headers: resp.GetHeaders(), + }, nil + case *TopKSketchesResponse: + matrix, err := sketch.TopKMatrixFromProto(r.Response) + if err != nil { + return logqlmodel.Result{}, fmt.Errorf("cannot decode topk sketch: %w", err) + } + + return logqlmodel.Result{ + Data: matrix, + Headers: resp.GetHeaders(), + }, nil + case *QuantileSketchResponse: + matrix, err := sketch.QuantileSketchMatrixFromProto(r.Response) + if err != nil { + return logqlmodel.Result{}, fmt.Errorf("cannot decode quantile sketch: %w", err) + } + return logqlmodel.Result{ + Data: matrix, + Headers: resp.GetHeaders(), + }, nil + default: + return logqlmodel.Result{}, fmt.Errorf("cannot decode (%T)", resp) + } +} + func QueryResponseWrap(res queryrangebase.Response) (*QueryResponse, error) { p := &QueryResponse{} @@ -141,5 +200,4 @@ func QueryResponseWrap(res queryrangebase.Response) (*QueryResponse, error) { } return p, nil - } diff --git a/pkg/querier/queryrange/queryrangebase/roundtrip.go b/pkg/querier/queryrange/queryrangebase/roundtrip.go index 3cfb7ab849a8a..a2dc31be0bbc5 100644 --- a/pkg/querier/queryrange/queryrangebase/roundtrip.go +++ b/pkg/querier/queryrange/queryrangebase/roundtrip.go @@ -18,13 +18,9 @@ package queryrangebase import ( "context" "flag" - "io" "net/http" "time" - "github.com/grafana/dskit/httpgrpc" - "github.com/grafana/dskit/user" - 
"github.com/opentracing/opentracing-go" "github.com/pkg/errors" ) @@ -116,79 +112,3 @@ type RoundTripFunc func(*http.Request) (*http.Response, error) func (f RoundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } - -type roundTripper struct { - roundTripperHandler - handler Handler - headers []string -} - -// NewRoundTripper merges a set of middlewares into an handler, then inject it into the `next` roundtripper -// using the codec to translate requests and responses. -func NewRoundTripper(next http.RoundTripper, codec Codec, headers []string, middlewares ...Middleware) http.RoundTripper { - transport := roundTripper{ - roundTripperHandler: roundTripperHandler{ - next: next, - codec: codec, - }, - headers: headers, - } - transport.handler = MergeMiddlewares(middlewares...).Wrap(&transport) - return transport -} - -func (q roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - // include the headers specified in the roundTripper during decoding the request. - request, err := q.codec.DecodeRequest(r.Context(), r, q.headers) - if err != nil { - return nil, err - } - - if span := opentracing.SpanFromContext(r.Context()); span != nil { - request.LogToSpan(span) - } - - response, err := q.handler.Do(r.Context(), request) - if err != nil { - return nil, err - } - - return q.codec.EncodeResponse(r.Context(), r, response) -} - -type roundTripperHandler struct { - next http.RoundTripper - codec Codec -} - -// NewRoundTripperHandler returns a handler that translates Loki requests into http requests -// and passes down these to the next RoundTripper. -func NewRoundTripperHandler(next http.RoundTripper, codec Codec) Handler { - return roundTripperHandler{ - next: next, - codec: codec, - } -} - -// Do implements Handler. -func (q roundTripperHandler) Do(ctx context.Context, r Request) (Response, error) { - request, err := q.codec.EncodeRequest(ctx, r) - if err != nil { - return nil, err - } - - if err := user.InjectOrgIDIntoHTTPRequest(ctx, request); err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - - response, err := q.next.RoundTrip(request) - if err != nil { - return nil, err - } - defer func() { - _, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024)) //nolint:errcheck - response.Body.Close() - }() - - return q.codec.DecodeResponse(ctx, response, r) -} diff --git a/pkg/querier/queryrange/querysharding.go b/pkg/querier/queryrange/querysharding.go index 038e0611f9362..b2af68b55b783 100644 --- a/pkg/querier/queryrange/querysharding.go +++ b/pkg/querier/queryrange/querysharding.go @@ -36,7 +36,6 @@ func NewQueryShardMiddleware( logger log.Logger, confs ShardingConfigs, engineOpts logql.EngineOpts, - _ queryrangebase.Codec, middlewareMetrics *queryrangebase.InstrumentMiddlewareMetrics, shardingMetrics *logql.MapperMetrics, limits Limits, diff --git a/pkg/querier/queryrange/querysharding_test.go b/pkg/querier/queryrange/querysharding_test.go index 1aa2b601057bb..e3e83f967ac04 100644 --- a/pkg/querier/queryrange/querysharding_test.go +++ b/pkg/querier/queryrange/querysharding_test.go @@ -410,7 +410,7 @@ func Test_InstantSharding(t *testing.T) { cpyPeriodConf.RowShards = 3 sharding := NewQueryShardMiddleware(log.NewNopLogger(), ShardingConfigs{ cpyPeriodConf, - }, testEngineOpts, DefaultCodec, queryrangebase.NewInstrumentMiddlewareMetrics(nil), + }, testEngineOpts, queryrangebase.NewInstrumentMiddlewareMetrics(nil), nilShardingMetrics, fakeLimits{ maxSeries: math.MaxInt32, diff --git a/pkg/querier/queryrange/roundtrip.go 
b/pkg/querier/queryrange/roundtrip.go index 450feb4b286f4..9c409d14a5a9f 100644 --- a/pkg/querier/queryrange/roundtrip.go +++ b/pkg/querier/queryrange/roundtrip.go @@ -17,11 +17,11 @@ import ( "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" - "github.com/grafana/loki/pkg/loghttp" + "github.com/grafana/loki/pkg/logproto" "github.com/grafana/loki/pkg/logql" "github.com/grafana/loki/pkg/logql/syntax" "github.com/grafana/loki/pkg/logqlmodel/stats" - "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" + base "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" "github.com/grafana/loki/pkg/storage/chunk/cache" "github.com/grafana/loki/pkg/storage/config" logutil "github.com/grafana/loki/pkg/util/log" @@ -29,7 +29,7 @@ import ( // Config is the configuration for the queryrange tripperware type Config struct { - queryrangebase.Config `yaml:",inline"` + base.Config `yaml:",inline"` Transformer UserIDTransformer `yaml:"-"` CacheIndexStatsResults bool `yaml:"cache_index_stats_results"` StatsCacheConfig IndexStatsCacheConfig `yaml:"index_stats_results_cache" doc:"description=If a cache config is not specified and cache_index_stats_results is true, the config for the results cache is used."` @@ -76,7 +76,7 @@ func (s StopperWrapper) Stop() { } } -func newResultsCacheFromConfig(cfg queryrangebase.ResultsCacheConfig, registerer prometheus.Registerer, log log.Logger, cacheType stats.CacheType) (cache.Cache, error) { +func newResultsCacheFromConfig(cfg base.ResultsCacheConfig, registerer prometheus.Registerer, log log.Logger, cacheType stats.CacheType) (cache.Cache, error) { if !cache.IsCacheConfigured(cfg.CacheConfig) { return nil, errors.Errorf("%s cache is not configured", cacheType) } @@ -93,17 +93,17 @@ func newResultsCacheFromConfig(cfg queryrangebase.ResultsCacheConfig, registerer return c, nil } -// NewTripperware returns a Tripperware configured with middlewares to align, split and cache requests. -func NewTripperware( +// NewMiddleware returns a Middleware configured with middlewares to align, split and cache requests. 
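A quick orientation for the signature change in this hunk: the constructor now returns a base.Middleware, which wraps a downstream base.Handler, rather than a function over http.RoundTripper. The sketch below uses simplified local stand-ins (Request, Response, Handler, Middleware and friends here are hypotheticals, not the queryrangebase definitions) to show how a returned middleware is applied with Wrap and invoked with Do.

// Minimal sketch of the handler/middleware shapes used in this patch.
// All types below are simplified stand-ins, not the real queryrangebase types.
package main

import (
	"context"
	"fmt"
)

type Request interface{ GetQuery() string }
type Response interface{}

type Handler interface {
	Do(ctx context.Context, r Request) (Response, error)
}

type HandlerFunc func(ctx context.Context, r Request) (Response, error)

func (f HandlerFunc) Do(ctx context.Context, r Request) (Response, error) { return f(ctx, r) }

type Middleware interface {
	Wrap(next Handler) Handler
}

type MiddlewareFunc func(next Handler) Handler

func (f MiddlewareFunc) Wrap(next Handler) Handler { return f(next) }

type fakeRequest struct{ query string }

func (r fakeRequest) GetQuery() string { return r.query }

// logQuery is an illustrative middleware: it observes the typed request and
// then delegates to the next handler, instead of wrapping an http.RoundTripper.
var logQuery = MiddlewareFunc(func(next Handler) Handler {
	return HandlerFunc(func(ctx context.Context, r Request) (Response, error) {
		fmt.Println("executing query:", r.GetQuery())
		return next.Do(ctx, r)
	})
})

func main() {
	downstream := HandlerFunc(func(_ context.Context, _ Request) (Response, error) {
		return "ok", nil
	})
	resp, err := logQuery.Wrap(downstream).Do(context.Background(), fakeRequest{query: `{app="foo"}`})
	fmt.Println(resp, err)
}

The code in this patch composes the real queryrangebase types the same way; the sketch only shows that composition now happens on typed requests and responses rather than on *http.Request.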
+func NewMiddleware( cfg Config, engineOpts logql.EngineOpts, log log.Logger, limits Limits, schema config.SchemaConfig, - cacheGenNumLoader queryrangebase.CacheGenNumberLoader, + cacheGenNumLoader base.CacheGenNumberLoader, retentionEnabled bool, registerer prometheus.Registerer, -) (queryrangebase.Tripperware, Stopper, error) { +) (base.Middleware, Stopper, error) { metrics := NewMetrics(registerer) var ( @@ -148,7 +148,7 @@ func NewTripperware( } } - var codec queryrangebase.Codec = DefaultCodec + var codec base.Codec = DefaultCodec if cfg.RequiredQueryResponseFormat == "protobuf" { codec = &RequestProtobufCodec{} } @@ -165,7 +165,7 @@ func NewTripperware( return nil, nil, err } - limitedTripperware, err := NewLimitedTripperware(cfg, engineOpts, log, limits, schema, codec, metrics, indexStatsTripperware) + limitedTripperware, err := NewLimitedTripperware(cfg, engineOpts, log, limits, schema, metrics, indexStatsTripperware, codec) if err != nil { return nil, nil, err } @@ -177,7 +177,7 @@ func NewTripperware( return nil, nil, err } - seriesTripperware, err := NewSeriesTripperware(cfg, log, limits, codec, metrics, schema) + seriesTripperware, err := NewSeriesTripperware(cfg, log, limits, metrics, schema, DefaultCodec) if err != nil { return nil, nil, err } @@ -197,32 +197,32 @@ func NewTripperware( return nil, nil, err } - return func(next http.RoundTripper) http.RoundTripper { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { var ( - metricRT = metricsTripperware(next) - limitedRT = limitedTripperware(next) - logFilterRT = logFilterTripperware(next) - seriesRT = seriesTripperware(next) - labelsRT = labelsTripperware(next) - instantRT = instantMetricTripperware(next) - statsRT = indexStatsTripperware(next) - seriesVolumeRT = seriesVolumeTripperware(next) + metricRT = metricsTripperware.Wrap(next) + limitedRT = limitedTripperware.Wrap(next) + logFilterRT = logFilterTripperware.Wrap(next) + seriesRT = seriesTripperware.Wrap(next) + labelsRT = labelsTripperware.Wrap(next) + instantRT = instantMetricTripperware.Wrap(next) + statsRT = indexStatsTripperware.Wrap(next) + seriesVolumeRT = seriesVolumeTripperware.Wrap(next) ) return newRoundTripper(log, next, limitedRT, logFilterRT, metricRT, seriesRT, labelsRT, instantRT, statsRT, seriesVolumeRT, limits) - }, StopperWrapper{resultsCache, statsCache, volumeCache}, nil + }), StopperWrapper{resultsCache, statsCache, volumeCache}, nil } type roundTripper struct { logger log.Logger - next, limited, log, metric, series, labels, instantMetric, indexStats, seriesVolume http.RoundTripper + next, limited, log, metric, series, labels, instantMetric, indexStats, seriesVolume base.Handler limits Limits } // newRoundTripper creates a new queryrange roundtripper -func newRoundTripper(logger log.Logger, next, limited, log, metric, series, labels, instantMetric, indexStats, seriesVolume http.RoundTripper, limits Limits) roundTripper { +func newRoundTripper(logger log.Logger, next, limited, log, metric, series, labels, instantMetric, indexStats, seriesVolume base.Handler, limits Limits) roundTripper { return roundTripper{ logger: logger, limited: limited, @@ -238,26 +238,18 @@ func newRoundTripper(logger log.Logger, next, limited, log, metric, series, labe } } -func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - logger := logutil.WithContext(req.Context(), r.logger) - err := req.ParseForm() - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } +func (r roundTripper) Do(ctx 
context.Context, req base.Request) (base.Response, error) { + logger := logutil.WithContext(ctx, r.logger) - switch op := getOperation(req.URL.Path); op { - case QueryRangeOp: - rangeQuery, err := loghttp.ParseRangeQuery(req) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - expr, err := syntax.ParseExpr(rangeQuery.Query) + switch op := req.(type) { + case *LokiRequest: + expr, err := syntax.ParseExpr(op.Query) if err != nil { return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } - queryHash := logql.HashedQuery(rangeQuery.Query) - level.Info(logger).Log("msg", "executing query", "type", "range", "query", rangeQuery.Query, "length", rangeQuery.End.Sub(rangeQuery.Start), "step", rangeQuery.Step, "query_hash", queryHash) + queryHash := logql.HashedQuery(op.Query) + level.Info(logger).Log("msg", "executing query", "type", "range", "query", op.Query, "length", op.EndTs.Sub(op.StartTs), "step", op.Step, "query_hash", queryHash) switch e := expr.(type) { case syntax.SampleExpr: @@ -268,112 +260,70 @@ func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { } for _, g := range groups { - if err := validateMatchers(req, r.limits, g.Matchers); err != nil { + if err := validateMatchers(ctx, r.limits, g.Matchers); err != nil { return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } } - return r.metric.RoundTrip(req) + return r.metric.Do(ctx, req) case syntax.LogSelectorExpr: - // Note, this function can mutate the request - expr, err := transformRegexQuery(req, e) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - if err := validateMaxEntriesLimits(req, rangeQuery.Limit, r.limits); err != nil { + if err := validateMaxEntriesLimits(ctx, op.Limit, r.limits); err != nil { return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } - if err := validateMatchers(req, r.limits, e.Matchers()); err != nil { + if err := validateMatchers(ctx, r.limits, e.Matchers()); err != nil { return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } // Only filter expressions are query sharded - if !expr.HasFilter() { - return r.limited.RoundTrip(req) + if !e.HasFilter() { + return r.limited.Do(ctx, req) } - return r.log.RoundTrip(req) + return r.log.Do(ctx, req) default: - return r.next.RoundTrip(req) - } - case SeriesOp: - sr, err := loghttp.ParseAndValidateSeriesQuery(req) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) + return r.next.Do(ctx, req) } + case *LokiSeriesRequest: + level.Info(logger).Log("msg", "executing query", "type", "series", "match", logql.PrintMatches(op.Match), "length", op.EndTs.Sub(op.StartTs)) - level.Info(logger).Log("msg", "executing query", "type", "series", "match", logql.PrintMatches(sr.Groups), "length", sr.End.Sub(sr.Start)) + return r.series.Do(ctx, req) + case *LabelRequest: + level.Info(logger).Log("msg", "executing query", "type", "labels", "label", op.Name, "length", op.LabelRequest.End.Sub(*op.LabelRequest.Start), "query", op.Query) - return r.series.RoundTrip(req) - case LabelNamesOp: - lr, err := loghttp.ParseLabelQuery(req) + return r.labels.Do(ctx, req) + case *LokiInstantRequest: + expr, err := syntax.ParseExpr(op.Query) if err != nil { return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } - level.Info(logger).Log("msg", "executing query", "type", "labels", "label", lr.Name, "length", lr.End.Sub(*lr.Start), "query", lr.Query) - - return r.labels.RoundTrip(req) - case InstantQueryOp: - instantQuery, 
err := loghttp.ParseInstantQuery(req) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - expr, err := syntax.ParseExpr(instantQuery.Query) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - - queryHash := logql.HashedQuery(instantQuery.Query) - level.Info(logger).Log("msg", "executing query", "type", "instant", "query", instantQuery.Query, "query_hash", queryHash) + queryHash := logql.HashedQuery(op.Query) + level.Info(logger).Log("msg", "executing query", "type", "instant", "query", op.Query, "query_hash", queryHash) switch expr.(type) { case syntax.SampleExpr: - return r.instantMetric.RoundTrip(req) + return r.instantMetric.Do(ctx, req) default: - return r.next.RoundTrip(req) - } - case IndexStatsOp: - statsQuery, err := loghttp.ParseIndexStatsQuery(req) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) + return r.next.Do(ctx, req) } - level.Info(logger).Log("msg", "executing query", "type", "stats", "query", statsQuery.Query, "length", statsQuery.End.Sub(statsQuery.Start)) + case *logproto.IndexStatsRequest: + level.Info(logger).Log("msg", "executing query", "type", "stats", "query", op.Matchers, "length", op.Through.Sub(op.From)) - return r.indexStats.RoundTrip(req) - case VolumeOp: - volumeQuery, err := loghttp.ParseVolumeInstantQuery(req) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - level.Info(logger).Log( - "msg", "executing query", - "type", "volume", - "query", volumeQuery.Query, - "length", volumeQuery.End.Sub(volumeQuery.Start), - "limit", volumeQuery.Limit, - "aggregate_by", volumeQuery.AggregateBy, - ) - - return r.seriesVolume.RoundTrip(req) - case VolumeRangeOp: - volumeQuery, err := loghttp.ParseVolumeRangeQuery(req) - if err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } + return r.indexStats.Do(ctx, req) + case *logproto.VolumeRequest: level.Info(logger).Log( "msg", "executing query", "type", "volume_range", - "query", volumeQuery.Query, - "length", volumeQuery.End.Sub(volumeQuery.Start), - "step", volumeQuery.Step, - "limit", volumeQuery.Limit, - "aggregate_by", volumeQuery.AggregateBy, + "query", op.Matchers, + "length", op.Through.Sub(op.From), + "step", op.Step, + "limit", op.Limit, + "aggregate_by", op.AggregateBy, ) - return r.seriesVolume.RoundTrip(req) + return r.seriesVolume.Do(ctx, req) default: - return r.next.RoundTrip(req) + return r.next.Do(ctx, req) } } @@ -434,20 +384,20 @@ func NewLogFilterTripperware( log log.Logger, limits Limits, schema config.SchemaConfig, - codec queryrangebase.Codec, + merger base.Merger, c cache.Cache, metrics *Metrics, - indexStatsTripperware queryrangebase.Tripperware, -) (queryrangebase.Tripperware, error) { - return func(next http.RoundTripper) http.RoundTripper { - statsHandler := queryrangebase.NewRoundTripperHandler(indexStatsTripperware(next), codec) + indexStatsTripperware base.Middleware, +) (base.Middleware, error) { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { + statsHandler := indexStatsTripperware.Wrap(next) - queryRangeMiddleware := []queryrangebase.Middleware{ + queryRangeMiddleware := []base.Middleware{ StatsCollectorMiddleware(), NewLimitsMiddleware(limits), NewQuerySizeLimiterMiddleware(schema.Configs, engineOpts, log, limits, statsHandler), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), - SplitByIntervalMiddleware(schema.Configs, limits, codec, 
splitByTime, metrics.SplitByMetrics), + base.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + SplitByIntervalMiddleware(schema.Configs, limits, merger, splitByTime, metrics.SplitByMetrics), } if cfg.CacheResults { @@ -455,7 +405,7 @@ func NewLogFilterTripperware( log, limits, c, - func(_ context.Context, r queryrangebase.Request) bool { + func(_ context.Context, r base.Request) bool { return !r.GetCachingOptions().Disabled }, cfg.Transformer, @@ -463,7 +413,7 @@ func NewLogFilterTripperware( ) queryRangeMiddleware = append( queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("log_results_cache", metrics.InstrumentMiddlewareMetrics), + base.InstrumentMiddleware("log_results_cache", metrics.InstrumentMiddlewareMetrics), queryCacheMiddleware, ) } @@ -474,7 +424,6 @@ func NewLogFilterTripperware( log, schema.Configs, engineOpts, - codec, metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware metrics.MiddlewareMapperMetrics.shardMapper, limits, @@ -492,16 +441,16 @@ func NewLogFilterTripperware( if cfg.MaxRetries > 0 { queryRangeMiddleware = append( - queryRangeMiddleware, queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + queryRangeMiddleware, base.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + base.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), ) } if len(queryRangeMiddleware) > 0 { - return NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) + return NewLimitedRoundTripper(next, limits, schema.Configs, queryRangeMiddleware...) } return next - }, nil + }), nil } // NewLimitedTripperware creates a new frontend tripperware responsible for handling log requests which are label matcher only, no filter expression. @@ -511,32 +460,32 @@ func NewLimitedTripperware( log log.Logger, limits Limits, schema config.SchemaConfig, - codec queryrangebase.Codec, metrics *Metrics, - indexStatsTripperware queryrangebase.Tripperware, -) (queryrangebase.Tripperware, error) { - return func(next http.RoundTripper) http.RoundTripper { - statsHandler := queryrangebase.NewRoundTripperHandler(indexStatsTripperware(next), codec) + indexStatsTripperware base.Middleware, + merger base.Merger, +) (base.Middleware, error) { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { + statsHandler := indexStatsTripperware.Wrap(next) - queryRangeMiddleware := []queryrangebase.Middleware{ + queryRangeMiddleware := []base.Middleware{ StatsCollectorMiddleware(), NewLimitsMiddleware(limits), NewQuerySizeLimiterMiddleware(schema.Configs, engineOpts, log, limits, statsHandler), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + base.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), // Limited queries only need to fetch up to the requested line limit worth of logs, // Our defaults for splitting and parallelism are much too aggressive for large customers and result in // potentially GB of logs being returned by all the shards and splits which will overwhelm the frontend // Therefore we force max parallelism to one so that these queries are executed sequentially. // Below we also fix the number of shards to a static number. 
- SplitByIntervalMiddleware(schema.Configs, WithMaxParallelism(limits, 1), codec, splitByTime, metrics.SplitByMetrics), + SplitByIntervalMiddleware(schema.Configs, WithMaxParallelism(limits, 1), merger, splitByTime, metrics.SplitByMetrics), NewQuerierSizeLimiterMiddleware(schema.Configs, engineOpts, log, limits, statsHandler), } if len(queryRangeMiddleware) > 0 { - return NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) + return NewLimitedRoundTripper(next, limits, schema.Configs, queryRangeMiddleware...) } return next - }, nil + }), nil } // NewSeriesTripperware creates a new frontend tripperware responsible for handling series requests @@ -544,24 +493,24 @@ func NewSeriesTripperware( cfg Config, log log.Logger, limits Limits, - codec queryrangebase.Codec, metrics *Metrics, schema config.SchemaConfig, -) (queryrangebase.Tripperware, error) { - queryRangeMiddleware := []queryrangebase.Middleware{ + merger base.Merger, +) (base.Middleware, error) { + queryRangeMiddleware := []base.Middleware{ StatsCollectorMiddleware(), NewLimitsMiddleware(limits), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + base.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), // The Series API needs to pull one chunk per series to extract the label set, which is much cheaper than iterating through all matching chunks. // Force a 24 hours split by for series API, this will be more efficient with our static daily bucket storage. // This would avoid queriers downloading chunks for same series over and over again for serving smaller queries. - SplitByIntervalMiddleware(schema.Configs, WithSplitByLimits(limits, 24*time.Hour), codec, splitByTime, metrics.SplitByMetrics), + SplitByIntervalMiddleware(schema.Configs, WithSplitByLimits(limits, 24*time.Hour), merger, splitByTime, metrics.SplitByMetrics), } if cfg.MaxRetries > 0 { queryRangeMiddleware = append(queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + base.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + base.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), ) } @@ -573,17 +522,17 @@ func NewSeriesTripperware( metrics.InstrumentMiddlewareMetrics, metrics.MiddlewareMapperMetrics.shardMapper, limits, - codec, + merger, ), ) } - return func(next http.RoundTripper) http.RoundTripper { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { if len(queryRangeMiddleware) > 0 { - return NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) + return NewLimitedRoundTripper(next, limits, schema.Configs, queryRangeMiddleware...) } return next - }, nil + }), nil } // NewLabelsTripperware creates a new frontend tripperware responsible for handling labels requests. 
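The constructors above now accept a base.Merger where they previously took the full Codec, and the callers still pass DefaultCodec, which suggests the codec also satisfies the narrower merging interface; only response merging is needed once requests stay typed end to end. A rough stand-alone sketch of that narrower dependency, with purely hypothetical local types, might look like this:

// Sketch of depending only on response merging, in the spirit of the split
// middlewares above. The interface and response type are hypothetical
// simplifications, not the queryrangebase definitions.
package main

import "fmt"

type response struct{ values []string }

// merger is the narrow capability the split middleware needs: combine the
// responses of the per-interval subqueries into a single response.
type merger interface {
	MergeResponse(responses ...response) (response, error)
}

type concatMerger struct{}

func (concatMerger) MergeResponse(responses ...response) (response, error) {
	var out response
	for _, r := range responses {
		out.values = append(out.values, r.values...)
	}
	return out, nil
}

func main() {
	var m merger = concatMerger{}
	merged, err := m.MergeResponse(
		response{values: []string{"sub-query 1"}},
		response{values: []string{"sub-query 2"}},
	)
	fmt.Println(merged, err)
}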
@@ -591,33 +540,33 @@ func NewLabelsTripperware( cfg Config, log log.Logger, limits Limits, - codec queryrangebase.Codec, + merger base.Merger, metrics *Metrics, schema config.SchemaConfig, -) (queryrangebase.Tripperware, error) { - queryRangeMiddleware := []queryrangebase.Middleware{ +) (base.Middleware, error) { + queryRangeMiddleware := []base.Middleware{ StatsCollectorMiddleware(), NewLimitsMiddleware(limits), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + base.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), // Force a 24 hours split by for labels API, this will be more efficient with our static daily bucket storage. // This is because the labels API is an index-only operation. - SplitByIntervalMiddleware(schema.Configs, WithSplitByLimits(limits, 24*time.Hour), codec, splitByTime, metrics.SplitByMetrics), + SplitByIntervalMiddleware(schema.Configs, WithSplitByLimits(limits, 24*time.Hour), merger, splitByTime, metrics.SplitByMetrics), } if cfg.MaxRetries > 0 { queryRangeMiddleware = append(queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + base.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + base.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), ) } - return func(next http.RoundTripper) http.RoundTripper { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { if len(queryRangeMiddleware) > 0 { // Do not forward any request header. - return queryrangebase.NewRoundTripper(next, codec, nil, queryRangeMiddleware...) + return base.MergeMiddlewares(queryRangeMiddleware...).Wrap(next) } return next - }, nil + }), nil } // NewMetricTripperware creates a new frontend tripperware responsible for handling metric queries @@ -627,30 +576,30 @@ func NewMetricTripperware( log log.Logger, limits Limits, schema config.SchemaConfig, - codec queryrangebase.Codec, + merger base.Merger, c cache.Cache, - cacheGenNumLoader queryrangebase.CacheGenNumberLoader, + cacheGenNumLoader base.CacheGenNumberLoader, retentionEnabled bool, - extractor queryrangebase.Extractor, + extractor base.Extractor, metrics *Metrics, - indexStatsTripperware queryrangebase.Tripperware, -) (queryrangebase.Tripperware, error) { + indexStatsTripperware base.Middleware, +) (base.Middleware, error) { cacheKey := cacheKeyLimits{limits, cfg.Transformer} - var queryCacheMiddleware queryrangebase.Middleware + var queryCacheMiddleware base.Middleware if cfg.CacheResults { var err error - queryCacheMiddleware, err = queryrangebase.NewResultsCacheMiddleware( + queryCacheMiddleware, err = base.NewResultsCacheMiddleware( log, c, cacheKey, limits, - codec, + merger, extractor, cacheGenNumLoader, - func(_ context.Context, r queryrangebase.Request) bool { + func(_ context.Context, r base.Request) bool { return !r.GetCachingOptions().Disabled }, - func(ctx context.Context, tenantIDs []string, r queryrangebase.Request) int { + func(ctx context.Context, tenantIDs []string, r base.Request) int { return MinWeightedParallelism( ctx, tenantIDs, @@ -668,10 +617,10 @@ func NewMetricTripperware( } } - return func(next http.RoundTripper) http.RoundTripper { - statsHandler := queryrangebase.NewRoundTripperHandler(indexStatsTripperware(next), codec) + return base.MiddlewareFunc(func(next base.Handler) base.Handler { + statsHandler := indexStatsTripperware.Wrap(next) - 
queryRangeMiddleware := []queryrangebase.Middleware{ + queryRangeMiddleware := []base.Middleware{ StatsCollectorMiddleware(), NewLimitsMiddleware(limits), } @@ -679,22 +628,22 @@ func NewMetricTripperware( if cfg.AlignQueriesWithStep { queryRangeMiddleware = append( queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("step_align", metrics.InstrumentMiddlewareMetrics), - queryrangebase.StepAlignMiddleware, + base.InstrumentMiddleware("step_align", metrics.InstrumentMiddlewareMetrics), + base.StepAlignMiddleware, ) } queryRangeMiddleware = append( queryRangeMiddleware, NewQuerySizeLimiterMiddleware(schema.Configs, engineOpts, log, limits, statsHandler), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), - SplitByIntervalMiddleware(schema.Configs, limits, codec, splitMetricByTime, metrics.SplitByMetrics), + base.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + SplitByIntervalMiddleware(schema.Configs, limits, merger, splitMetricByTime, metrics.SplitByMetrics), ) if cfg.CacheResults { queryRangeMiddleware = append( queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("results_cache", metrics.InstrumentMiddlewareMetrics), + base.InstrumentMiddleware("results_cache", metrics.InstrumentMiddlewareMetrics), queryCacheMiddleware, ) } @@ -705,7 +654,6 @@ func NewMetricTripperware( log, schema.Configs, engineOpts, - codec, metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware metrics.MiddlewareMapperMetrics.shardMapper, limits, @@ -724,23 +672,24 @@ func NewMetricTripperware( if cfg.MaxRetries > 0 { queryRangeMiddleware = append( queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + base.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + base.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), ) } // Finally, if the user selected any query range middleware, stitch it in. if len(queryRangeMiddleware) > 0 { - rt := NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) - return queryrangebase.RoundTripFunc(func(r *http.Request) (*http.Response, error) { - if !strings.HasSuffix(r.URL.Path, "/query_range") { - return next.RoundTrip(r) + rt := NewLimitedRoundTripper(next, limits, schema.Configs, queryRangeMiddleware...) 
+ return base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { + _, ok := r.(*LokiRequest) + if !ok { + return next.Do(ctx, r) } - return rt.RoundTrip(r) + return rt.Do(ctx, r) }) } return next - }, nil + }), nil } // NewInstantMetricTripperware creates a new frontend tripperware responsible for handling metric queries @@ -750,14 +699,14 @@ func NewInstantMetricTripperware( log log.Logger, limits Limits, schema config.SchemaConfig, - codec queryrangebase.Codec, + merger base.Merger, metrics *Metrics, - indexStatsTripperware queryrangebase.Tripperware, -) (queryrangebase.Tripperware, error) { - return func(next http.RoundTripper) http.RoundTripper { - statsHandler := queryrangebase.NewRoundTripperHandler(indexStatsTripperware(next), codec) + indexStatsTripperware base.Middleware, +) (base.Middleware, error) { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { + statsHandler := indexStatsTripperware.Wrap(next) - queryRangeMiddleware := []queryrangebase.Middleware{ + queryRangeMiddleware := []base.Middleware{ StatsCollectorMiddleware(), NewLimitsMiddleware(limits), NewQuerySizeLimiterMiddleware(schema.Configs, engineOpts, log, limits, statsHandler), @@ -770,7 +719,6 @@ func NewInstantMetricTripperware( log, schema.Configs, engineOpts, - codec, metrics.InstrumentMiddlewareMetrics, // instrumentation is included in the sharding middleware metrics.MiddlewareMapperMetrics.shardMapper, limits, @@ -783,16 +731,16 @@ func NewInstantMetricTripperware( if cfg.MaxRetries > 0 { queryRangeMiddleware = append( queryRangeMiddleware, - queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + base.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + base.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), ) } if len(queryRangeMiddleware) > 0 { - return NewLimitedRoundTripper(next, codec, limits, schema.Configs, queryRangeMiddleware...) + return NewLimitedRoundTripper(next, limits, schema.Configs, queryRangeMiddleware...) } return next - }, nil + }), nil } func NewVolumeTripperware( @@ -800,28 +748,28 @@ func NewVolumeTripperware( log log.Logger, limits Limits, schema config.SchemaConfig, - codec queryrangebase.Codec, + merger base.Merger, c cache.Cache, - cacheGenNumLoader queryrangebase.CacheGenNumberLoader, + cacheGenNumLoader base.CacheGenNumberLoader, retentionEnabled bool, metrics *Metrics, -) (queryrangebase.Tripperware, error) { +) (base.Middleware, error) { // Parallelize the volume requests, so it doesn't send a huge request to a single index-gw (i.e. {app=~".+"} for 30d). // Indices are sharded by 24 hours, so we split the volume request in 24h intervals. 
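The comments above give the reasoning for forcing 24-hour splits on volume requests (the same is done for index stats further down). Purely as an illustration of the arithmetic, and not of the real SplitByIntervalMiddleware, splitting a 25-hour window into 24-hour pieces yields two subrequests, which lines up with the tripperware tests in this patch expecting two downstream calls for such ranges:

// Illustrative only: split a [from, through) range into fixed-size pieces.
// This mirrors the intent of the 24h split described above, not the real
// SplitByIntervalMiddleware implementation.
package main

import (
	"fmt"
	"time"
)

func splitByInterval(from, through time.Time, interval time.Duration) [][2]time.Time {
	var out [][2]time.Time
	for start := from; start.Before(through); start = start.Add(interval) {
		end := start.Add(interval)
		if end.After(through) {
			end = through
		}
		out = append(out, [2]time.Time{start, end})
	}
	return out
}

func main() {
	through := time.Date(2019, 12, 2, 11, 0, 0, 0, time.UTC)
	from := through.Add(-25 * time.Hour) // just over one 24h index period
	for _, r := range splitByInterval(from, through, 24*time.Hour) {
		fmt.Println(r[0], "->", r[1]) // prints two subranges
	}
}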
limits = WithSplitByLimits(limits, 24*time.Hour) - var cacheMiddleware queryrangebase.Middleware + var cacheMiddleware base.Middleware if cfg.CacheVolumeResults { var err error cacheMiddleware, err = NewVolumeCacheMiddleware( log, limits, - codec, + merger, c, cacheGenNumLoader, - func(_ context.Context, r queryrangebase.Request) bool { + func(_ context.Context, r base.Request) bool { return !r.GetCachingOptions().Disabled }, - func(ctx context.Context, tenantIDs []string, r queryrangebase.Request) int { + func(ctx context.Context, tenantIDs []string, r base.Request) int { return MinWeightedParallelism( ctx, tenantIDs, @@ -843,7 +791,7 @@ func NewVolumeTripperware( indexTw, err := sharedIndexTripperware( cacheMiddleware, cfg, - codec, + merger, limits, log, metrics, @@ -855,47 +803,33 @@ func NewVolumeTripperware( } return volumeFeatureFlagRoundTripper( - volumeRangeTripperware(codec, indexTw), + volumeRangeTripperware(indexTw), limits, ), nil } -func volumeRangeTripperware(codec queryrangebase.Codec, nextTW queryrangebase.Tripperware) func(next http.RoundTripper) http.RoundTripper { - return func(next http.RoundTripper) http.RoundTripper { - nextRT := nextTW(next) - - return queryrangebase.RoundTripFunc(func(r *http.Request) (*http.Response, error) { - request, err := codec.DecodeRequest(r.Context(), r, nil) - if err != nil { - return nil, err - } - - seriesVolumeMiddlewares := []queryrangebase.Middleware{ +func volumeRangeTripperware(nextTW base.Middleware) base.Middleware { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { + return base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { + seriesVolumeMiddlewares := []base.Middleware{ StatsCollectorMiddleware(), NewVolumeMiddleware(), + nextTW, } // wrap nextRT with our new middleware - response, err := queryrangebase.MergeMiddlewares( + return base.MergeMiddlewares( seriesVolumeMiddlewares..., - ).Wrap( - VolumeDownstreamHandler(nextRT, codec), - ).Do(r.Context(), request) - - if err != nil { - return nil, err - } - - return codec.EncodeResponse(r.Context(), r, response) + ).Wrap(next).Do(ctx, r) }) - } + }) } -func volumeFeatureFlagRoundTripper(nextTW queryrangebase.Tripperware, limits Limits) func(next http.RoundTripper) http.RoundTripper { - return func(next http.RoundTripper) http.RoundTripper { - nextRt := nextTW(next) - return queryrangebase.RoundTripFunc(func(r *http.Request) (*http.Response, error) { - userID, err := user.ExtractOrgID(r.Context()) +func volumeFeatureFlagRoundTripper(nextTW base.Middleware, limits Limits) base.Middleware { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { + nextRt := nextTW.Wrap(next) + return base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { + userID, err := user.ExtractOrgID(ctx) if err != nil { return nil, err } @@ -904,9 +838,9 @@ func volumeFeatureFlagRoundTripper(nextTW queryrangebase.Tripperware, limits Lim return nil, httpgrpc.Errorf(http.StatusNotFound, "not found") } - return nextRt.RoundTrip(r) + return nextRt.Do(ctx, r) }) - } + }) } func NewIndexStatsTripperware( @@ -914,29 +848,29 @@ func NewIndexStatsTripperware( log log.Logger, limits Limits, schema config.SchemaConfig, - codec queryrangebase.Codec, + merger base.Merger, c cache.Cache, - cacheGenNumLoader queryrangebase.CacheGenNumberLoader, + cacheGenNumLoader base.CacheGenNumberLoader, retentionEnabled bool, metrics *Metrics, -) (queryrangebase.Tripperware, error) { +) (base.Middleware, error) { // Parallelize the index stats 
requests, so it doesn't send a huge request to a single index-gw (i.e. {app=~".+"} for 30d). // Indices are sharded by 24 hours, so we split the stats request in 24h intervals. limits = WithSplitByLimits(limits, 24*time.Hour) - var cacheMiddleware queryrangebase.Middleware + var cacheMiddleware base.Middleware if cfg.CacheIndexStatsResults { var err error cacheMiddleware, err = NewIndexStatsCacheMiddleware( log, limits, - codec, + merger, c, cacheGenNumLoader, - func(_ context.Context, r queryrangebase.Request) bool { + func(_ context.Context, r base.Request) bool { return !r.GetCachingOptions().Disabled }, - func(ctx context.Context, tenantIDs []string, r queryrangebase.Request) int { + func(ctx context.Context, tenantIDs []string, r base.Request) int { return MinWeightedParallelism( ctx, tenantIDs, @@ -958,7 +892,7 @@ func NewIndexStatsTripperware( return sharedIndexTripperware( cacheMiddleware, cfg, - codec, + merger, limits, log, metrics, @@ -967,25 +901,25 @@ func NewIndexStatsTripperware( } func sharedIndexTripperware( - cacheMiddleware queryrangebase.Middleware, + cacheMiddleware base.Middleware, cfg Config, - codec queryrangebase.Codec, + merger base.Merger, limits Limits, log log.Logger, metrics *Metrics, schema config.SchemaConfig, -) (queryrangebase.Tripperware, error) { - return func(next http.RoundTripper) http.RoundTripper { - middlewares := []queryrangebase.Middleware{ +) (base.Middleware, error) { + return base.MiddlewareFunc(func(next base.Handler) base.Handler { + middlewares := []base.Middleware{ NewLimitsMiddleware(limits), - queryrangebase.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), - SplitByIntervalMiddleware(schema.Configs, limits, codec, splitByTime, metrics.SplitByMetrics), + base.InstrumentMiddleware("split_by_interval", metrics.InstrumentMiddlewareMetrics), + SplitByIntervalMiddleware(schema.Configs, limits, merger, splitByTime, metrics.SplitByMetrics), } if cacheMiddleware != nil { middlewares = append( middlewares, - queryrangebase.InstrumentMiddleware("log_results_cache", metrics.InstrumentMiddlewareMetrics), + base.InstrumentMiddleware("log_results_cache", metrics.InstrumentMiddlewareMetrics), cacheMiddleware, ) } @@ -993,11 +927,11 @@ func sharedIndexTripperware( if cfg.MaxRetries > 0 { middlewares = append( middlewares, - queryrangebase.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), - queryrangebase.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), + base.InstrumentMiddleware("retry", metrics.InstrumentMiddlewareMetrics), + base.NewRetryMiddleware(log, cfg.MaxRetries, metrics.RetryMiddlewareMetrics), ) } - return queryrangebase.NewRoundTripper(next, codec, nil, middlewares...) 
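For the change at the end of sharedIndexTripperware, where the middleware list is now stitched together with base.MergeMiddlewares and wrapped around the downstream handler directly instead of being turned into an HTTP round tripper with a codec, here is a compact sketch. The function aliases are simplified stand-ins, and the outermost-first ordering shown is an assumption about the real helper rather than a documented guarantee:

// Sketch of merging a middleware chain and wrapping a final handler, in the
// spirit of base.MergeMiddlewares(middlewares...).Wrap(next) used above.
package main

import (
	"context"
	"fmt"
)

type handler func(ctx context.Context, query string) (string, error)
type middleware func(next handler) handler

// mergeMiddlewares composes middlewares so that the first one listed ends up
// outermost (an assumed ordering, see the lead-in above).
func mergeMiddlewares(mws ...middleware) middleware {
	return func(next handler) handler {
		for i := len(mws) - 1; i >= 0; i-- {
			next = mws[i](next)
		}
		return next
	}
}

// tag wraps the downstream response so the nesting order becomes visible.
func tag(name string) middleware {
	return func(next handler) handler {
		return func(ctx context.Context, query string) (string, error) {
			resp, err := next(ctx, query)
			return name + "(" + resp + ")", err
		}
	}
}

func main() {
	chain := mergeMiddlewares(tag("limits"), tag("split_by_interval"), tag("retry"))
	base := func(_ context.Context, q string) (string, error) { return q, nil }
	resp, _ := chain(base)(context.Background(), "downstream")
	fmt.Println(resp) // limits(split_by_interval(retry(downstream)))
}

The printed nesting mirrors the order in which the limits, split-by-interval and retry middlewares are appended above.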
- }, nil + return base.MergeMiddlewares(middlewares...).Wrap(next) + }), nil } diff --git a/pkg/querier/queryrange/roundtrip_test.go b/pkg/querier/queryrange/roundtrip_test.go index a06dc98e93f34..6c8f6d8af5943 100644 --- a/pkg/querier/queryrange/roundtrip_test.go +++ b/pkg/querier/queryrange/roundtrip_test.go @@ -1,22 +1,17 @@ package queryrange import ( - "bytes" "context" + "errors" "fmt" - "io" "math" "net/http" - "net/http/httptest" - "net/url" "sort" - "strconv" "sync" "testing" "time" "github.com/grafana/dskit/httpgrpc" - "github.com/grafana/dskit/middleware" "github.com/grafana/dskit/user" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" @@ -30,12 +25,12 @@ import ( "github.com/grafana/loki/pkg/logql" "github.com/grafana/loki/pkg/logqlmodel" "github.com/grafana/loki/pkg/logqlmodel/stats" - "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" + base "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" "github.com/grafana/loki/pkg/storage/chunk/cache" "github.com/grafana/loki/pkg/storage/config" + "github.com/grafana/loki/pkg/storage/stores/index/seriesvolume" "github.com/grafana/loki/pkg/util" util_log "github.com/grafana/loki/pkg/util/log" - "github.com/grafana/loki/pkg/util/marshal" "github.com/grafana/loki/pkg/util/validation" valid "github.com/grafana/loki/pkg/validation" ) @@ -43,11 +38,11 @@ import ( var ( testTime = time.Date(2019, 12, 2, 11, 10, 10, 10, time.UTC) testConfig = Config{ - Config: queryrangebase.Config{ + Config: base.Config{ AlignQueriesWithStep: true, MaxRetries: 3, CacheResults: true, - ResultsCacheConfig: queryrangebase.ResultsCacheConfig{ + ResultsCacheConfig: base.ResultsCacheConfig{ CacheConfig: cache.Config{ EmbeddedCache: cache.EmbeddedCacheConfig{ Enabled: true, @@ -60,7 +55,7 @@ var ( Transformer: nil, CacheIndexStatsResults: true, StatsCacheConfig: IndexStatsCacheConfig{ - ResultsCacheConfig: queryrangebase.ResultsCacheConfig{ + ResultsCacheConfig: base.ResultsCacheConfig{ CacheConfig: cache.Config{ EmbeddedCache: cache.EmbeddedCacheConfig{ Enabled: true, @@ -71,7 +66,7 @@ var ( }, }, VolumeCacheConfig: VolumeCacheConfig{ - ResultsCacheConfig: queryrangebase.ResultsCacheConfig{ + ResultsCacheConfig: base.ResultsCacheConfig{ CacheConfig: cache.Config{ EmbeddedCache: cache.EmbeddedCacheConfig{ Enabled: true, @@ -152,19 +147,16 @@ var ( } ) -func getQueryAndStatsHandler(queryHandler, statsHandler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/loki/api/v1/index/stats" { - statsHandler.ServeHTTP(w, r) - return +func getQueryAndStatsHandler(queryHandler, statsHandler base.Handler) base.Handler { + return base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { + switch r.(type) { + case *logproto.IndexStatsRequest: + return statsHandler.Do(ctx, r) + case *LokiRequest, *LokiInstantRequest: + return queryHandler.Do(ctx, r) } - if r.URL.Path == "/loki/api/v1/query_range" || r.URL.Path == "/loki/api/v1/query" { - queryHandler.ServeHTTP(w, r) - return - } - - panic("Request not supported") + return nil, fmt.Errorf("Request not supported: %T", r) }) } @@ -181,7 +173,7 @@ func TestMetricsTripperware(t *testing.T) { noCacheTestCfg := testConfig noCacheTestCfg.CacheResults = false noCacheTestCfg.CacheIndexStatsResults = false - tpw, stopper, err := NewTripperware(noCacheTestCfg, testEngineOpts, util_log.Logger, l, config.SchemaConfig{ + tpw, stopper, err := NewMiddleware(noCacheTestCfg, testEngineOpts, 
util_log.Logger, l, config.SchemaConfig{ Configs: testSchemasTSDB, }, nil, false, nil) if stopper != nil { @@ -200,20 +192,12 @@ func TestMetricsTripperware(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) // Test MaxQueryBytesRead limit statsCount, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 2000}) queryCount, queryHandler := counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + h := getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Error(t, err) require.Equal(t, 1, *statsCount) require.Equal(t, 0, *queryCount) @@ -221,28 +205,23 @@ func TestMetricsTripperware(t *testing.T) { // Test MaxQuerierBytesRead limit statsCount, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 200}) queryCount, queryHandler = counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Error(t, err) require.Equal(t, 0, *queryCount) require.Equal(t, 2, *statsCount) // testing retry _, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) - retries, queryHandler := counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + retries, queryHandler := counterWithError(errors.New("handle error")) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) // 3 retries configured. require.GreaterOrEqual(t, *retries, 3) require.Error(t, err) - rt.Close() - - rt, err = newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() // Configure with cache - tpw, stopper, err = NewTripperware(testConfig, testEngineOpts, util_log.Logger, l, config.SchemaConfig{ + tpw, stopper, err = NewMiddleware(testConfig, testEngineOpts, util_log.Logger, l, config.SchemaConfig{ Configs: testSchemasTSDB, }, nil, false, nil) if stopper != nil { @@ -253,23 +232,19 @@ func TestMetricsTripperware(t *testing.T) { // testing split interval _, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) count, queryHandler := promqlResult(matrix) - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - resp, err := tpw(rt).RoundTrip(req) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + lokiResponse, err := tpw.Wrap(h).Do(ctx, lreq) // 2 queries require.Equal(t, 2, *count) require.NoError(t, err) - lokiResponse, err := DefaultCodec.DecodeResponse(ctx, resp, lreq) - require.NoError(t, err) // testing cache count, queryHandler = counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - cacheResp, err := tpw(rt).RoundTrip(req) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + lokiCacheResponse, err := tpw.Wrap(h).Do(ctx, lreq) // 0 queries result are cached. 
require.Equal(t, 0, *count) require.NoError(t, err) - lokiCacheResponse, err := DefaultCodec.DecodeResponse(ctx, cacheResp, lreq) - require.NoError(t, err) require.Equal(t, lokiResponse.(*LokiPromResponse).Response, lokiCacheResponse.(*LokiPromResponse).Response) } @@ -284,14 +259,11 @@ func TestLogFilterTripperware(t *testing.T) { noCacheTestCfg := testConfig noCacheTestCfg.CacheResults = false noCacheTestCfg.CacheIndexStatsResults = false - tpw, stopper, err := NewTripperware(noCacheTestCfg, testEngineOpts, util_log.Logger, l, config.SchemaConfig{Configs: testSchemasTSDB}, nil, false, nil) + tpw, stopper, err := NewMiddleware(noCacheTestCfg, testEngineOpts, util_log.Logger, l, config.SchemaConfig{Configs: testSchemasTSDB}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := &LokiRequest{ Query: `{app="foo"} |= "foo"`, @@ -303,38 +275,29 @@ func TestLogFilterTripperware(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) // testing limit count, h := promqlResult(streams) - rt.setHandler(h) - _, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Equal(t, 0, *count) require.Error(t, err) // set the query length back to normal lreq.StartTs = testTime.Add(-6 * time.Hour) - req, err = DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) // testing retry _, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) - retries, queryHandler := counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + retries, queryHandler := counterWithError(errors.New("handler failed")) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.GreaterOrEqual(t, *retries, 3) require.Error(t, err) // Test MaxQueryBytesRead limit statsCount, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 2000}) queryCount, queryHandler := counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Error(t, err) require.Equal(t, 1, *statsCount) require.Equal(t, 0, *queryCount) @@ -342,8 +305,8 @@ func TestLogFilterTripperware(t *testing.T) { // Test MaxQuerierBytesRead limit statsCount, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 200}) queryCount, queryHandler = counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Error(t, err) require.Equal(t, 2, *statsCount) require.Equal(t, 0, *queryCount) @@ -362,14 +325,11 @@ func TestInstantQueryTripperware(t *testing.T) { queryTimeout: 1 * time.Minute, maxSeries: 1, } - tpw, stopper, err := NewTripperware(testShardingConfigNoCache, testEngineOpts, util_log.Logger, l, config.SchemaConfig{Configs: testSchemasTSDB}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testShardingConfigNoCache, testEngineOpts, util_log.Logger, l, config.SchemaConfig{Configs: testSchemasTSDB}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := 
newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := &LokiInstantRequest{ Query: `sum by (job) (bytes_rate({cluster="dev-us-central-0"}[15m]))`, @@ -380,18 +340,12 @@ func TestInstantQueryTripperware(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) // Test MaxQueryBytesRead limit statsCount, statsHandler := indexStatsResult(logproto.IndexStatsResponse{Bytes: 2000}) queryCount, queryHandler := counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + h := getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Error(t, err) require.Equal(t, 1, *statsCount) require.Equal(t, 0, *queryCount) @@ -399,33 +353,28 @@ func TestInstantQueryTripperware(t *testing.T) { // Test MaxQuerierBytesRead limit statsCount, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 200}) queryCount, queryHandler = counter() - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - _, err = tpw(rt).RoundTrip(req) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Error(t, err) require.Equal(t, 2, *statsCount) require.Equal(t, 0, *queryCount) count, queryHandler := promqlResult(vector) _, statsHandler = indexStatsResult(logproto.IndexStatsResponse{Bytes: 10}) - rt.setHandler(getQueryAndStatsHandler(queryHandler, statsHandler)) - resp, err := tpw(rt).RoundTrip(req) + h = getQueryAndStatsHandler(queryHandler, statsHandler) + lokiResponse, err := tpw.Wrap(h).Do(ctx, lreq) require.Equal(t, 1, *count) require.NoError(t, err) - lokiResponse, err := DefaultCodec.DecodeResponse(ctx, resp, lreq) - require.NoError(t, err) require.IsType(t, &LokiPromResponse{}, lokiResponse) } func TestSeriesTripperware(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := &LokiSeriesRequest{ Match: []string{`{job="varlogs"}`}, @@ -435,22 +384,15 @@ func TestSeriesTripperware(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) + count, h := seriesResult(series) + lokiSeriesResponse, err := tpw.Wrap(h).Do(ctx, lreq) require.NoError(t, err) - count, h := seriesResult(series) - rt.setHandler(h) - resp, err := tpw(rt).RoundTrip(req) // 2 queries require.Equal(t, 2, *count) - require.NoError(t, err) - lokiSeriesResponse, err := DefaultCodec.DecodeResponse(ctx, resp, lreq) res, ok := lokiSeriesResponse.(*LokiSeriesResponse) - require.Equal(t, true, ok) + require.True(t, ok) // make sure we return unique series since responses from // SplitByInterval middleware might have duplicate series @@ -459,14 +401,11 @@ func TestSeriesTripperware(t *testing.T) { } func 
TestLabelsTripperware(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := NewLabelRequest( testTime.Add(-25*time.Hour), // bigger than the limit @@ -477,44 +416,42 @@ func TestLabelsTripperware(t *testing.T) { ) ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) handler := newFakeHandler( // we expect 2 calls. - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.NoError(t, marshal.WriteLabelResponseJSON([]string{"foo", "bar", "blop"}, w)) + base.HandlerFunc(func(_ context.Context, _ base.Request) (base.Response, error) { + return &LokiLabelNamesResponse{ + Status: "success", + Data: []string{"foo", "bar", "blop"}, + Version: uint32(1), + }, nil }), - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.NoError(t, marshal.WriteLabelResponseJSON([]string{"foo", "bar", "blip"}, w)) + base.HandlerFunc(func(_ context.Context, _ base.Request) (base.Response, error) { + return &LokiLabelNamesResponse{ + Status: "success", + Data: []string{"foo", "bar", "blip"}, + Version: uint32(1), + }, nil }), ) - rt.setHandler(handler) - resp, err := tpw(rt).RoundTrip(req) + lokiLabelsResponse, err := tpw.Wrap(handler).Do(ctx, lreq) + require.NoError(t, err) + // verify 2 calls have been made to downstream. 
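TestLabelsTripperware above feeds two fake per-split label responses through the wrapped handler and expects their union back, in first-seen order and with duplicates dropped. A tiny stand-alone illustration of that behaviour (not the package's actual response-merging code):

// Order-preserving union the labels test above expects:
// {"foo","bar","blop"} + {"foo","bar","blip"} => {"foo","bar","blop","blip"}.
package main

import "fmt"

func mergeLabelNames(responses ...[]string) []string {
	seen := map[string]struct{}{}
	var out []string
	for _, names := range responses {
		for _, name := range names {
			if _, ok := seen[name]; ok {
				continue
			}
			seen[name] = struct{}{}
			out = append(out, name)
		}
	}
	return out
}

func main() {
	fmt.Println(mergeLabelNames(
		[]string{"foo", "bar", "blop"},
		[]string{"foo", "bar", "blip"},
	)) // [foo bar blop blip]
}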
require.Equal(t, 2, handler.count) - require.NoError(t, err) - lokiLabelsResponse, err := DefaultCodec.DecodeResponse(ctx, resp, lreq) res, ok := lokiLabelsResponse.(*LokiLabelNamesResponse) - require.Equal(t, true, ok) + require.True(t, ok) require.Equal(t, []string{"foo", "bar", "blop", "blip"}, res.Data) require.Equal(t, "success", res.Status) require.NoError(t, err) } func TestIndexStatsTripperware(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := &logproto.IndexStatsRequest{ Matchers: `{job="varlogs"}`, @@ -523,12 +460,6 @@ func TestIndexStatsTripperware(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) response := logproto.IndexStatsResponse{ Streams: 100, @@ -538,8 +469,7 @@ func TestIndexStatsTripperware(t *testing.T) { } count, h := indexStatsResult(response) - rt.setHandler(h) - _, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) // 2 queries require.Equal(t, 2, *count) require.NoError(t, err) @@ -547,14 +477,11 @@ func TestIndexStatsTripperware(t *testing.T) { // Test the cache. // It should have the answer already so the query handler shouldn't be hit count, h = indexStatsResult(response) - rt.setHandler(h) - resp, err := tpw(rt).RoundTrip(req) + indexStatsResponse, err := tpw.Wrap(h).Do(ctx, lreq) require.NoError(t, err) require.Equal(t, 0, *count) // Test the response is the expected - indexStatsResponse, err := DefaultCodec.DecodeResponse(ctx, resp, lreq) - require.NoError(t, err) res, ok := indexStatsResponse.(*IndexStatsResponse) require.Equal(t, true, ok) require.Equal(t, response.Streams*2, res.Response.Streams) @@ -565,47 +492,37 @@ func TestIndexStatsTripperware(t *testing.T) { func TestVolumeTripperware(t *testing.T) { t.Run("instant queries hardcode step to 0 and return a prometheus style vector response", func(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, volumeEnabled: true}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + limits := fakeLimits{ + maxQueryLength: 48 * time.Hour, + volumeEnabled: true, + maxSeries: 42, + } + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, limits, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() - lreq := &logproto.VolumeRequest{ - Matchers: `{job="varlogs"}`, - From: model.TimeFromUnixNano(testTime.Add(-25 * time.Hour).UnixNano()), // bigger than split by interval limit - Through: model.TimeFromUnixNano(testTime.UnixNano()), - Limit: 10, - Step: 42, // this should be ignored and set to 0 + Matchers: `{job="varlogs"}`, + From: model.TimeFromUnixNano(testTime.Add(-25 * time.Hour).UnixNano()), 
// bigger than split by interval limit + Through: model.TimeFromUnixNano(testTime.UnixNano()), + Limit: 10, + Step: 0, // Travis/Trevor: this should be ignored and set to 0. Karsten: Why? + AggregateBy: seriesvolume.DefaultAggregateBy, } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - - req.URL.Path = "/loki/api/v1/index/volume" count, h := seriesVolumeResult(seriesVolume) - rt.setHandler(h) - resp, err := tpw(rt).RoundTrip(req) + volumeResp, err := tpw.Wrap(h).Do(ctx, lreq) require.NoError(t, err) require.Equal(t, 2, *count) // 2 queries from splitting - volumeResp, err := DefaultCodec.DecodeResponse(ctx, resp, nil) - require.NoError(t, err) - - expected := queryrangebase.PrometheusData{ + expected := base.PrometheusData{ ResultType: loghttp.ResultTypeVector, - Result: []queryrangebase.SampleStream{ + Result: []base.SampleStream{ { Labels: []logproto.LabelAdapter{{ Name: "bar", @@ -636,41 +553,29 @@ func TestVolumeTripperware(t *testing.T) { }) t.Run("range queries return a prometheus style metrics response, putting volumes in buckets based on the step", func(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, volumeEnabled: true}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, volumeEnabled: true}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() - start := testTime.Add(-5 * time.Hour) end := testTime lreq := &logproto.VolumeRequest{ - Matchers: `{job="varlogs"}`, - From: model.TimeFromUnixNano(start.UnixNano()), // bigger than split by interval limit - Through: model.TimeFromUnixNano(end.UnixNano()), - Step: time.Hour.Milliseconds(), - Limit: 10, + Matchers: `{job="varlogs"}`, + From: model.TimeFromUnixNano(start.UnixNano()), // bigger than split by interval limit + Through: model.TimeFromUnixNano(end.UnixNano()), + Step: time.Hour.Milliseconds(), + Limit: 10, + AggregateBy: seriesvolume.DefaultAggregateBy, } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - - req.URL.Path = "/loki/api/v1/index/volume_range" count, h := seriesVolumeResult(seriesVolume) - rt.setHandler(h) - resp, err := tpw(rt).RoundTrip(req) + volumeResp, err := tpw.Wrap(h).Do(ctx, lreq) require.NoError(t, err) /* @@ -680,9 +585,6 @@ func TestVolumeTripperware(t *testing.T) { */ require.Equal(t, 6, *count) // 6 queries from splitting into step buckets - volumeResp, err := DefaultCodec.DecodeResponse(ctx, resp, nil) - require.NoError(t, err) - barBazExpectedSamples := []logproto.LegacySample{} util.ForInterval(time.Hour, start, end, true, func(s, _ time.Time) { barBazExpectedSamples = append(barBazExpectedSamples, logproto.LegacySample{ @@ -705,9 +607,9 @@ func TestVolumeTripperware(t *testing.T) { return fooBarExpectedSamples[i].TimestampMs < fooBarExpectedSamples[j].TimestampMs }) - expected := queryrangebase.PrometheusData{ + expected := base.PrometheusData{ 
ResultType: loghttp.ResultTypeMatrix, - Result: []queryrangebase.SampleStream{ + Result: []base.SampleStream{ { Labels: []logproto.LabelAdapter{{ Name: "bar", @@ -742,7 +644,7 @@ func TestNewTripperware_Caches(t *testing.T) { { name: "results cache disabled, stats cache disabled", config: Config{ - Config: queryrangebase.Config{ + Config: base.Config{ CacheResults: false, }, CacheIndexStatsResults: false, @@ -753,9 +655,9 @@ func TestNewTripperware_Caches(t *testing.T) { { name: "results cache enabled, stats cache disabled", config: Config{ - Config: queryrangebase.Config{ + Config: base.Config{ CacheResults: true, - ResultsCacheConfig: queryrangebase.ResultsCacheConfig{ + ResultsCacheConfig: base.ResultsCacheConfig{ CacheConfig: cache.Config{ EmbeddedCache: cache.EmbeddedCacheConfig{ MaxSizeMB: 1, @@ -772,9 +674,9 @@ func TestNewTripperware_Caches(t *testing.T) { { name: "results cache enabled, stats cache enabled", config: Config{ - Config: queryrangebase.Config{ + Config: base.Config{ CacheResults: true, - ResultsCacheConfig: queryrangebase.ResultsCacheConfig{ + ResultsCacheConfig: base.ResultsCacheConfig{ CacheConfig: cache.Config{ EmbeddedCache: cache.EmbeddedCacheConfig{ MaxSizeMB: 1, @@ -791,9 +693,9 @@ func TestNewTripperware_Caches(t *testing.T) { { name: "results cache enabled, stats cache enabled but different", config: Config{ - Config: queryrangebase.Config{ + Config: base.Config{ CacheResults: true, - ResultsCacheConfig: queryrangebase.ResultsCacheConfig{ + ResultsCacheConfig: base.ResultsCacheConfig{ CacheConfig: cache.Config{ EmbeddedCache: cache.EmbeddedCacheConfig{ Enabled: true, @@ -804,7 +706,7 @@ func TestNewTripperware_Caches(t *testing.T) { }, CacheIndexStatsResults: true, StatsCacheConfig: IndexStatsCacheConfig{ - ResultsCacheConfig: queryrangebase.ResultsCacheConfig{ + ResultsCacheConfig: base.ResultsCacheConfig{ CacheConfig: cache.Config{ EmbeddedCache: cache.EmbeddedCacheConfig{ Enabled: true, @@ -820,7 +722,7 @@ func TestNewTripperware_Caches(t *testing.T) { { name: "results cache enabled (no config provided)", config: Config{ - Config: queryrangebase.Config{ + Config: base.Config{ CacheResults: true, }, }, @@ -829,7 +731,7 @@ func TestNewTripperware_Caches(t *testing.T) { { name: "results cache disabled, stats cache enabled (no config provided)", config: Config{ - Config: queryrangebase.Config{ + Config: base.Config{ CacheResults: false, }, CacheIndexStatsResults: true, @@ -839,7 +741,7 @@ func TestNewTripperware_Caches(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - _, stopper, err := NewTripperware(tc.config, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + _, stopper, err := NewMiddleware(tc.config, testEngineOpts, util_log.Logger, fakeLimits{maxQueryLength: 48 * time.Hour, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } @@ -869,14 +771,11 @@ func TestNewTripperware_Caches(t *testing.T) { } func TestLogNoFilter(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := 
newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := &LokiRequest{ Query: `{app="foo"}`, @@ -888,126 +787,44 @@ func TestLogNoFilter(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) count, h := promqlResult(streams) - rt.setHandler(h) - _, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Equal(t, 1, *count) require.Nil(t, err) } -func TestRegexpParamsSupport(t *testing.T) { - l := WithSplitByLimits(fakeLimits{maxSeries: 1, maxQueryParallelism: 2}, 4*time.Hour) - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, l, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) - if stopper != nil { - defer stopper.Stop() - } - require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() - - lreq := &LokiRequest{ - Query: `{app="foo"}`, - Limit: 1000, - StartTs: testTime.Add(-6 * time.Hour), - EndTs: testTime, - Direction: logproto.FORWARD, - Path: "/loki/api/v1/query_range", - } - - ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - // fudge a regexp params - params := req.URL.Query() - params.Set("regexp", "foo") - req.URL.RawQuery = params.Encode() - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - - count, h := promqlResult(streams) - rt.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // the query params should contain the filter. - require.Contains(t, r.URL.Query().Get("query"), `|~ "foo"`) - h.ServeHTTP(rw, r) - })) - _, err = tpw(rt).RoundTrip(req) - require.Equal(t, 2, *count) // expecting the query to also be splitted since it has a filter. 
- require.NoError(t, err) -} - func TestPostQueries(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "/loki/api/v1/query_range", nil) - data := url.Values{ - "query": {`{app="foo"} |~ "foo"`}, - } - body := bytes.NewBufferString(data.Encode()) - req.Body = io.NopCloser(body) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) - req = req.WithContext(user.InjectOrgID(context.Background(), "1")) - require.NoError(t, err) - _, err = newRoundTripper( + lreq := &LokiRequest{Query: `{app="foo"} |~ "foo"`} + ctx := user.InjectOrgID(context.Background(), "1") + handler := base.HandlerFunc(func(context.Context, base.Request) (base.Response, error) { + t.Error("unexpected default roundtripper called") + return nil, nil + }) + _, err := newRoundTripper( util_log.Logger, - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected default roundtripper called") - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected default roundtripper called") - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected metric roundtripper called") - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected series roundtripper called") - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected labels roundtripper called") - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected instant roundtripper called") - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected indexStats roundtripper called") - return nil, nil - }), - queryrangebase.RoundTripFunc(func(*http.Request) (*http.Response, error) { - t.Error("unexpected labelVolume roundtripper called") + handler, + handler, + base.HandlerFunc(func(context.Context, base.Request) (base.Response, error) { return nil, nil }), + handler, + handler, + handler, + handler, + handler, + handler, fakeLimits{}, - ).RoundTrip(req) + ).Do(ctx, lreq) require.NoError(t, err) } func TestTripperware_EntriesLimit(t *testing.T) { - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxEntriesLimitPerQuery: 5000, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, fakeLimits{maxEntriesLimitPerQuery: 5000, maxQueryParallelism: 1}, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() lreq := &LokiRequest{ Query: `{app="foo"}`, @@ -1019,15 +836,16 @@ func TestTripperware_EntriesLimit(t *testing.T) { } ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) + called := false + h := base.HandlerFunc(func(context.Context, base.Request) (base.Response, error) { + called = true + return nil, nil + }) - 
_, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) require.Equal(t, httpgrpc.Errorf(http.StatusBadRequest, "max entries limit per query exceeded, limit > max_entries_limit (10000 > 5000)"), err) + require.False(t, called) } func TestTripperware_RequiredLabels(t *testing.T) { @@ -1048,16 +866,12 @@ func TestTripperware_RequiredLabels(t *testing.T) { } { t.Run(test.qs, func(t *testing.T) { limits := fakeLimits{maxEntriesLimitPerQuery: 5000, maxQueryParallelism: 1, requiredLabels: []string{"app"}} - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, limits, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, limits, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() _, h := promqlResult(test.response) - rt.setHandler(h) lreq := &LokiRequest{ Query: test.qs, @@ -1067,16 +881,13 @@ func TestTripperware_RequiredLabels(t *testing.T) { Direction: logproto.FORWARD, Path: "/loki/api/v1/query_range", } + // See loghttp.step + step := time.Duration(int(math.Max(math.Floor(lreq.EndTs.Sub(lreq.StartTs).Seconds()/250), 1))) * time.Second + lreq.Step = step.Milliseconds() ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - _, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) if test.expectedError != "" { require.Equal(t, httpgrpc.Errorf(http.StatusBadRequest, test.expectedError), err) } else { @@ -1159,17 +970,13 @@ func TestTripperware_RequiredNumberLabels(t *testing.T) { maxQueryParallelism: 1, requiredNumberLabels: tc.requiredNumberLabels, } - tpw, stopper, err := NewTripperware(testConfig, testEngineOpts, util_log.Logger, limits, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(testConfig, testEngineOpts, util_log.Logger, limits, config.SchemaConfig{Configs: testSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() _, h := promqlResult(tc.response) - rt.setHandler(h) lreq := &LokiRequest{ Query: tc.query, @@ -1179,16 +986,13 @@ func TestTripperware_RequiredNumberLabels(t *testing.T) { Direction: logproto.FORWARD, Path: "/loki/api/v1/query_range", } + // See loghttp.step + step := time.Duration(int(math.Max(math.Floor(lreq.EndTs.Sub(lreq.StartTs).Seconds()/250), 1))) * time.Second + lreq.Step = step.Milliseconds() ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, lreq) - require.NoError(t, err) - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - - _, err = tpw(rt).RoundTrip(req) + _, err = tpw.Wrap(h).Do(ctx, lreq) if tc.expectedError != noErr { require.Equal(t, httpgrpc.Errorf(http.StatusBadRequest, tc.expectedError), err) } else { @@ -1284,7 +1088,7 @@ func TestMetricsTripperware_SplitShardStats(t *testing.T) { for _, tc := range []struct { name string - request queryrangebase.Request + request base.Request expectedSplitStats int64 expectedShardStats int64 }{ @@ -1342,30 +1146,16 @@ func TestMetricsTripperware_SplitShardStats(t *testing.T) { }, } { 
t.Run(tc.name, func(t *testing.T) { - tpw, stopper, err := NewTripperware(statsTestCfg, testEngineOpts, util_log.Logger, l, config.SchemaConfig{Configs: statsSchemas}, nil, false, nil) + tpw, stopper, err := NewMiddleware(statsTestCfg, testEngineOpts, util_log.Logger, l, config.SchemaConfig{Configs: statsSchemas}, nil, false, nil) if stopper != nil { defer stopper.Stop() } require.NoError(t, err) ctx := user.InjectOrgID(context.Background(), "1") - req, err := DefaultCodec.EncodeRequest(ctx, tc.request) - require.NoError(t, err) - - req = req.WithContext(ctx) - err = user.InjectOrgIDIntoHTTPRequest(ctx, req) - require.NoError(t, err) - - rt, err := newfakeRoundTripper() - require.NoError(t, err) - defer rt.Close() _, h := promqlResult(matrix) - rt.setHandler(h) - resp, err := tpw(rt).RoundTrip(req) - require.NoError(t, err) - - lokiResponse, err := DefaultCodec.DecodeResponse(ctx, resp, tc.request) + lokiResponse, err := tpw.Wrap(h).Do(ctx, tc.request) require.NoError(t, err) require.Equal(t, tc.expectedSplitStats, lokiResponse.(*LokiPromResponse).Statistics.Summary.Splits) @@ -1474,110 +1264,97 @@ func (f fakeLimits) TSDBMaxBytesPerShard(_ string) int { return valid.DefaultTSDBMaxBytesPerShard } -func counter() (*int, http.Handler) { +func counter() (*int, base.Handler) { count := 0 var lock sync.Mutex - return &count, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return &count, base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { lock.Lock() defer lock.Unlock() count++ + return base.NewEmptyPrometheusResponse(), nil }) } -func promqlResult(v parser.Value) (*int, http.Handler) { +func counterWithError(err error) (*int, base.Handler) { count := 0 var lock sync.Mutex - return &count, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return &count, base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { lock.Lock() defer lock.Unlock() - if err := marshal.WriteQueryResponseJSON(v, stats.Result{}, w); err != nil { - panic(err) - } count++ + return nil, err }) } -func seriesResult(v logproto.SeriesResponse) (*int, http.Handler) { +func promqlResult(v parser.Value) (*int, base.Handler) { count := 0 var lock sync.Mutex - return &count, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return &count, base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { lock.Lock() defer lock.Unlock() - if err := marshal.WriteSeriesResponseJSON(v.GetSeries(), w); err != nil { - panic(err) + count++ + params, err := ParamsFromRequest(r) + if err != nil { + return nil, err } + result := logqlmodel.Result{Data: v} + return ResultToResponse(result, params) + }) +} + +func seriesResult(v logproto.SeriesResponse) (*int, base.Handler) { + count := 0 + var lock sync.Mutex + return &count, base.HandlerFunc(func(ctx context.Context, r base.Request) (base.Response, error) { + lock.Lock() + defer lock.Unlock() count++ + return &LokiSeriesResponse{ + Status: "success", + Version: 1, + Data: v.Series, + }, nil }) } -func indexStatsResult(v logproto.IndexStatsResponse) (*int, http.Handler) { +func indexStatsResult(v logproto.IndexStatsResponse) (*int, base.Handler) { count := 0 var lock sync.Mutex - return &count, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return &count, base.HandlerFunc(func(_ context.Context, _ base.Request) (base.Response, error) { lock.Lock() defer lock.Unlock() - if err := marshal.WriteIndexStatsResponseJSON(&v, w); err != nil { - panic(err) - } 
count++ + return &IndexStatsResponse{Response: &v}, nil }) } -func seriesVolumeResult(v logproto.VolumeResponse) (*int, http.Handler) { +func seriesVolumeResult(v logproto.VolumeResponse) (*int, base.Handler) { count := 0 var lock sync.Mutex - return &count, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return &count, base.HandlerFunc(func(_ context.Context, _ base.Request) (base.Response, error) { lock.Lock() defer lock.Unlock() - if err := marshal.WriteVolumeResponseJSON(&v, w); err != nil { - panic(err) - } count++ + return &VolumeResponse{Response: &v}, nil }) } type fakeHandler struct { count int lock sync.Mutex - calls []http.Handler + calls []base.Handler } -func newFakeHandler(calls ...http.Handler) *fakeHandler { +func newFakeHandler(calls ...base.Handler) *fakeHandler { return &fakeHandler{calls: calls} } -func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (f *fakeHandler) Do(ctx context.Context, req base.Request) (base.Response, error) { f.lock.Lock() defer f.lock.Unlock() - f.calls[f.count].ServeHTTP(w, req) + r, err := f.calls[f.count].Do(ctx, req) f.count++ -} - -type fakeRoundTripper struct { - *httptest.Server - host string -} - -func newfakeRoundTripper() (*fakeRoundTripper, error) { - s := httptest.NewServer(nil) - u, err := url.Parse(s.URL) - if err != nil { - return nil, err - } - return &fakeRoundTripper{ - Server: s, - host: u.Host, - }, nil -} - -func (s *fakeRoundTripper) setHandler(h http.Handler) { - s.Config.Handler = middleware.AuthenticateUser.Wrap(h) -} - -func (s fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - r.URL.Scheme = "http" - r.URL.Host = s.host - return http.DefaultTransport.RoundTrip(r) + return r, err } func toMs(t time.Time) int64 { diff --git a/pkg/querier/queryrange/volume.go b/pkg/querier/queryrange/volume.go index 44b88a1d907ca..305397ff6d6e0 100644 --- a/pkg/querier/queryrange/volume.go +++ b/pkg/querier/queryrange/volume.go @@ -2,13 +2,10 @@ package queryrange import ( "context" - "net/http" "sort" "time" "github.com/grafana/dskit/concurrency" - "github.com/grafana/dskit/httpgrpc" - "github.com/grafana/dskit/user" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" @@ -22,26 +19,6 @@ import ( "github.com/grafana/loki/pkg/util" ) -func VolumeDownstreamHandler(nextRT http.RoundTripper, codec queryrangebase.Codec) queryrangebase.Handler { - return queryrangebase.HandlerFunc(func(ctx context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { - request, err := codec.EncodeRequest(ctx, req) - if err != nil { - return nil, err - } - - if err := user.InjectOrgIDIntoHTTPRequest(ctx, request); err != nil { - return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) - } - - resp, err := nextRT.RoundTrip(request) - if err != nil { - return nil, err - } - - return codec.DecodeResponse(ctx, resp, req) - }) -} - func NewVolumeMiddleware() queryrangebase.Middleware { return queryrangebase.MiddlewareFunc(func(next queryrangebase.Handler) queryrangebase.Handler { return queryrangebase.HandlerFunc(func(ctx context.Context, req queryrangebase.Request) (queryrangebase.Response, error) { diff --git a/pkg/util/limiter/combined_limits.go b/pkg/util/limiter/combined_limits.go index 59f0b6dec3a49..40d6fd508a4d4 100644 --- a/pkg/util/limiter/combined_limits.go +++ b/pkg/util/limiter/combined_limits.go @@ -5,8 +5,8 @@ import ( "github.com/grafana/loki/pkg/compactor" "github.com/grafana/loki/pkg/distributor" 
"github.com/grafana/loki/pkg/ingester" - "github.com/grafana/loki/pkg/querier" - "github.com/grafana/loki/pkg/querier/queryrange" + querier_limits "github.com/grafana/loki/pkg/querier/limits" + queryrange_limits "github.com/grafana/loki/pkg/querier/queryrange/limits" "github.com/grafana/loki/pkg/ruler" "github.com/grafana/loki/pkg/scheduler" "github.com/grafana/loki/pkg/storage" @@ -17,8 +17,8 @@ type CombinedLimits interface { compactor.Limits distributor.Limits ingester.Limits - querier.Limits - queryrange.Limits + querier_limits.Limits + queryrange_limits.Limits ruler.RulesLimits scheduler.Limits storage.StoreLimits diff --git a/pkg/util/querylimits/propagation.go b/pkg/util/querylimits/propagation.go index 75cce84e870dd..f0e5fbc8f6b49 100644 --- a/pkg/util/querylimits/propagation.go +++ b/pkg/util/querylimits/propagation.go @@ -44,14 +44,19 @@ func MarshalQueryLimits(limits *QueryLimits) ([]byte, error) { // InjectQueryLimitsHTTP adds the query limits to the request headers. func InjectQueryLimitsHTTP(r *http.Request, limits *QueryLimits) error { + return InjectQueryLimitsHeader(&r.Header, limits) +} + +// InjectQueryLimitsHeader adds the query limits to the headers. +func InjectQueryLimitsHeader(h *http.Header, limits *QueryLimits) error { // Ensure any existing policy sets are erased - r.Header.Del(HTTPHeaderQueryLimitsKey) + h.Del(HTTPHeaderQueryLimitsKey) encodedLimits, err := MarshalQueryLimits(limits) if err != nil { return err } - r.Header.Add(HTTPHeaderQueryLimitsKey, string(encodedLimits)) + h.Add(HTTPHeaderQueryLimitsKey, string(encodedLimits)) return nil } diff --git a/pkg/util/querylimits/tripperware.go b/pkg/util/querylimits/tripperware.go deleted file mode 100644 index a7608b98951b7..0000000000000 --- a/pkg/util/querylimits/tripperware.go +++ /dev/null @@ -1,51 +0,0 @@ -package querylimits - -import ( - "net/http" - - "github.com/grafana/loki/pkg/querier/queryrange/queryrangebase" -) - -type tripperwareWrapper struct { - next http.RoundTripper - wrapped http.RoundTripper -} - -// WrapTripperware wraps the existing tripperware to make sure the query limit policy headers are propagated -func WrapTripperware(existing queryrangebase.Tripperware) queryrangebase.Tripperware { - return func(next http.RoundTripper) http.RoundTripper { - limitsTrw := &tripperwareWrapper{ - next: next, - } - limitsTrw.wrapped = existing(queryrangebase.RoundTripFunc(limitsTrw.PostWrappedRoundTrip)) - return limitsTrw - } -} - -func (t *tripperwareWrapper) RoundTrip(r *http.Request) (*http.Response, error) { - ctx := r.Context() - - limits := ExtractQueryLimitsContext(ctx) - - if limits != nil { - ctx = InjectQueryLimitsContext(ctx, *limits) - r = r.Clone(ctx) - } - - return t.wrapped.RoundTrip(r) -} - -func (t *tripperwareWrapper) PostWrappedRoundTrip(r *http.Request) (*http.Response, error) { - ctx := r.Context() - - limits := ExtractQueryLimitsContext(ctx) - - if limits != nil { - err := InjectQueryLimitsHTTP(r, limits) - if err != nil { - return nil, err - } - } - - return t.next.RoundTrip(r) -}