From c127d3040bdcee0f3a5275c91ac428e95a1ff21e Mon Sep 17 00:00:00 2001 From: Karsten Jeschkies Date: Wed, 1 Nov 2023 18:50:02 +0100 Subject: [PATCH] Check switches on syntax.Expr for exhaustivness. --- .golangci.yml | 1 + pkg/logql/engine.go | 2 +- pkg/logql/evaluator.go | 2 +- pkg/logql/optimize.go | 4 ++-- pkg/logql/rangemapper.go | 8 ++++---- pkg/logql/syntax/ast.go | 4 +++- pkg/logql/syntax/walk.go | 2 +- pkg/logql/syntax/walk_test.go | 4 ++-- pkg/querier/multi_tenant_querier.go | 2 +- pkg/querier/queryrange/split_by_interval.go | 2 +- 10 files changed, 17 insertions(+), 14 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index f1d4093919a59..a835d27d6fa88 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -79,6 +79,7 @@ linters: - goimports - gosimple - staticcheck + - gochecksumtype disable: - unused - unparam diff --git a/pkg/logql/engine.go b/pkg/logql/engine.go index 1edf86da3ed58..038db129f35b6 100644 --- a/pkg/logql/engine.go +++ b/pkg/logql/engine.go @@ -430,7 +430,7 @@ func (q *query) evalSample(ctx context.Context, expr syntax.SampleExpr) (promql_ func (q *query) checkIntervalLimit(expr syntax.SampleExpr, limit time.Duration) error { var err error - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { switch e := e.(type) { case *syntax.RangeAggregationExpr: if e.Left == nil || e.Left.Interval <= limit { diff --git a/pkg/logql/evaluator.go b/pkg/logql/evaluator.go index 07a056e4ffe58..153b8c9a5f330 100644 --- a/pkg/logql/evaluator.go +++ b/pkg/logql/evaluator.go @@ -112,7 +112,7 @@ func Sortable(q Params) (bool, error) { if err != nil { return false, err } - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { rangeExpr, ok := e.(*syntax.VectorAggregationExpr) if !ok { return diff --git a/pkg/logql/optimize.go b/pkg/logql/optimize.go index de15bce40e200..1f00153e18b87 100644 --- a/pkg/logql/optimize.go +++ b/pkg/logql/optimize.go @@ -6,7 +6,7 @@ import "github.com/grafana/loki/pkg/logql/syntax" func optimizeSampleExpr(expr syntax.SampleExpr) (syntax.SampleExpr, error) { var skip bool // we skip sharding AST for now, it's not easy to clone them since they are not part of the language. - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { switch e.(type) { case *ConcatSampleExpr, *DownstreamSampleExpr: skip = true @@ -28,7 +28,7 @@ func optimizeSampleExpr(expr syntax.SampleExpr) (syntax.SampleExpr, error) { // removeLineformat removes unnecessary line_format within a SampleExpr. func removeLineformat(expr syntax.SampleExpr) { - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { rangeExpr, ok := e.(*syntax.RangeAggregationExpr) if !ok { return diff --git a/pkg/logql/rangemapper.go b/pkg/logql/rangemapper.go index 4cb3965ee4910..cc63944bc07e9 100644 --- a/pkg/logql/rangemapper.go +++ b/pkg/logql/rangemapper.go @@ -177,7 +177,7 @@ func (m RangeMapper) Map(expr syntax.SampleExpr, vectorAggrPushdown *syntax.Vect // Example: expression `count_over_time({app="foo"}[10m])` returns 10m func getRangeInterval(expr syntax.SampleExpr) time.Duration { var rangeInterval time.Duration - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { switch concrete := e.(type) { case *syntax.RangeAggregationExpr: rangeInterval = concrete.Left.Interval @@ -190,7 +190,7 @@ func getRangeInterval(expr syntax.SampleExpr) time.Duration { // such as `| json` or `| logfmt`, that would result in an exploding amount of series in downstream queries. func hasLabelExtractionStage(expr syntax.SampleExpr) bool { found := false - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { switch concrete := e.(type) { case *syntax.LogfmtParserExpr: found = true @@ -278,7 +278,7 @@ func (m RangeMapper) vectorAggrWithRangeDownstreams(expr *syntax.RangeAggregatio // Returns the updated downstream ConcatSampleExpr. func appendDownstream(downstreams *ConcatSampleExpr, expr syntax.SampleExpr, interval time.Duration, offset time.Duration) *ConcatSampleExpr { sampleExpr := clone(expr) - sampleExpr.Walk(func(e interface{}) { + sampleExpr.Walk(func(e syntax.Expr) { switch concrete := e.(type) { case *syntax.RangeAggregationExpr: concrete.Left.Interval = interval @@ -300,7 +300,7 @@ func getOffsets(expr syntax.SampleExpr) []time.Duration { // Expect to always find at most 1 offset, so preallocate it accordingly offsets := make([]time.Duration, 0, 1) - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { switch concrete := e.(type) { case *syntax.RangeAggregationExpr: offsets = append(offsets, concrete.Left.Offset) diff --git a/pkg/logql/syntax/ast.go b/pkg/logql/syntax/ast.go index aa4aa7fa80617..babb92d6c27b8 100644 --- a/pkg/logql/syntax/ast.go +++ b/pkg/logql/syntax/ast.go @@ -22,6 +22,8 @@ import ( ) // Expr is the root expression which can be a SampleExpr or LogSelectorExpr +// +//sumtype:decl type Expr interface { logQLExpr() // ensure it's not implemented accidentally Shardable() bool // A recursive check on the AST to see if it's shardable. @@ -2106,7 +2108,7 @@ func (e *VectorExpr) MatcherGroups() ([]MatcherRange, error) { return nil, e.er func (e *VectorExpr) Extractor() (log.SampleExtractor, error) { return nil, nil } func ReducesLabels(e Expr) (conflict bool) { - e.Walk(func(e interface{}) { + e.Walk(func(e Expr) { switch expr := e.(type) { case *RangeAggregationExpr: if groupingReducesLabels(expr.Grouping) { diff --git a/pkg/logql/syntax/walk.go b/pkg/logql/syntax/walk.go index 3a8b85d92d0b2..291ec8b31036f 100644 --- a/pkg/logql/syntax/walk.go +++ b/pkg/logql/syntax/walk.go @@ -1,6 +1,6 @@ package syntax -type WalkFn = func(e interface{}) +type WalkFn = func(e Expr) func walkAll(f WalkFn, xs ...Walkable) { for _, x := range xs { diff --git a/pkg/logql/syntax/walk_test.go b/pkg/logql/syntax/walk_test.go index 678e89df99c48..3350515b9c461 100644 --- a/pkg/logql/syntax/walk_test.go +++ b/pkg/logql/syntax/walk_test.go @@ -32,7 +32,7 @@ func Test_Walkable(t *testing.T) { require.Nil(t, err) var cnt int - expr.Walk(func(_ interface{}) { cnt++ }) + expr.Walk(func(_ Expr) { cnt++ }) require.Equal(t, test.want, cnt) }) } @@ -77,7 +77,7 @@ func Test_AppendMatchers(t *testing.T) { expr, err := ParseExpr(test.expr) require.NoError(t, err) - expr.Walk(func(e interface{}) { + expr.Walk(func(e Expr) { switch me := e.(type) { case *MatchersExpr: me.AppendMatchers(test.matchers) diff --git a/pkg/querier/multi_tenant_querier.go b/pkg/querier/multi_tenant_querier.go index c9b1b56b8b284..f4881df48a6d7 100644 --- a/pkg/querier/multi_tenant_querier.go +++ b/pkg/querier/multi_tenant_querier.go @@ -227,7 +227,7 @@ func removeTenantSelector(params logql.SelectSampleParams, tenantIDs []string) ( // replaceMatchers traverses the passed expression and replaces all matchers. func replaceMatchers(expr syntax.Expr, matchers []*labels.Matcher) syntax.Expr { expr, _ = syntax.Clone(expr) - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { switch concrete := e.(type) { case *syntax.MatchersExpr: concrete.Mts = matchers diff --git a/pkg/querier/queryrange/split_by_interval.go b/pkg/querier/queryrange/split_by_interval.go index f3f2c13d60042..da8326a678ec5 100644 --- a/pkg/querier/queryrange/split_by_interval.go +++ b/pkg/querier/queryrange/split_by_interval.go @@ -322,7 +322,7 @@ func maxRangeVectorAndOffsetDuration(q string) (time.Duration, time.Duration, er } var maxRVDuration, maxOffset time.Duration - expr.Walk(func(e interface{}) { + expr.Walk(func(e syntax.Expr) { if r, ok := e.(*syntax.LogRange); ok { if r.Interval > maxRVDuration { maxRVDuration = r.Interval