From da8cc13506f99b310ea6505b1ee319bb9b22dde8 Mon Sep 17 00:00:00 2001 From: Owen Mansel-Chan Date: Mon, 11 Mar 2024 09:57:02 +0000 Subject: [PATCH 1/2] go extractor: avoid long string concatenations When we see "a" + "b" + "c" + "d", do not add a row to the constvalues table for the intermiediate strings "ab" and "abc". We still have entries for the string literals ("a", "b", "c", and "d") and the whole string concatenation ("abcd"). --- go/extractor/extractor.go | 142 ++++++++++++++++++++++---------------- 1 file changed, 83 insertions(+), 59 deletions(-) diff --git a/go/extractor/extractor.go b/go/extractor/extractor.go index f2ba68a20f0e..97f3ed743fb1 100644 --- a/go/extractor/extractor.go +++ b/go/extractor/extractor.go @@ -794,7 +794,7 @@ func extractLocalScope(tw *trap.Writer, scope *types.Scope, parentScopeLabel tra func extractFileNode(tw *trap.Writer, nd *ast.File) { lbl := tw.Labeler.FileLabel() - extractExpr(tw, nd.Name, lbl, 0) + extractExpr(tw, nd.Name, lbl, 0, false) for i, decl := range nd.Decls { extractDecl(tw, decl, lbl, i) @@ -851,7 +851,7 @@ func emitScopeNodeInfo(tw *trap.Writer, nd ast.Node, lbl trap.Label) { } // extractExpr extracts AST information for the given expression and all its subexpressions -func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { +func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, skipExtractingValue bool) { if expr == nil { return } @@ -900,7 +900,7 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { return } kind = dbscheme.EllipsisExpr.Index() - extractExpr(tw, expr.Elt, lbl, 0) + extractExpr(tw, expr.Elt, lbl, 0, false) case *ast.BasicLit: if expr == nil { return @@ -932,28 +932,28 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { return } kind = dbscheme.FuncLitExpr.Index() - extractExpr(tw, expr.Type, lbl, 0) + extractExpr(tw, expr.Type, lbl, 0, false) extractStmt(tw, expr.Body, lbl, 1) case *ast.CompositeLit: if expr == nil { return } kind = dbscheme.CompositeLitExpr.Index() - extractExpr(tw, expr.Type, lbl, 0) + extractExpr(tw, expr.Type, lbl, 0, false) extractExprs(tw, expr.Elts, lbl, 1, 1) case *ast.ParenExpr: if expr == nil { return } kind = dbscheme.ParenExpr.Index() - extractExpr(tw, expr.X, lbl, 0) + extractExpr(tw, expr.X, lbl, 0, false) case *ast.SelectorExpr: if expr == nil { return } kind = dbscheme.SelectorExpr.Index() - extractExpr(tw, expr.X, lbl, 0) - extractExpr(tw, expr.Sel, lbl, 1) + extractExpr(tw, expr.X, lbl, 0, false) + extractExpr(tw, expr.Sel, lbl, 1, false) case *ast.IndexExpr: if expr == nil { return @@ -974,8 +974,8 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { kind = dbscheme.IndexExpr.Index() } } - extractExpr(tw, expr.X, lbl, 0) - extractExpr(tw, expr.Index, lbl, 1) + extractExpr(tw, expr.X, lbl, 0, false) + extractExpr(tw, expr.Index, lbl, 1, false) case *ast.IndexListExpr: if expr == nil { return @@ -993,30 +993,30 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { kind = dbscheme.GenericTypeInstantiationExpr.Index() } } - extractExpr(tw, expr.X, lbl, 0) + extractExpr(tw, expr.X, lbl, 0, false) extractExprs(tw, expr.Indices, lbl, 1, 1) case *ast.SliceExpr: if expr == nil { return } kind = dbscheme.SliceExpr.Index() - extractExpr(tw, expr.X, lbl, 0) - extractExpr(tw, expr.Low, lbl, 1) - extractExpr(tw, expr.High, lbl, 2) - extractExpr(tw, expr.Max, lbl, 3) + extractExpr(tw, expr.X, lbl, 0, false) + extractExpr(tw, expr.Low, lbl, 1, false) + extractExpr(tw, expr.High, lbl, 2, false) + extractExpr(tw, expr.Max, lbl, 3, false) case *ast.TypeAssertExpr: if expr == nil { return } kind = dbscheme.TypeAssertExpr.Index() - extractExpr(tw, expr.X, lbl, 0) - extractExpr(tw, expr.Type, lbl, 1) + extractExpr(tw, expr.X, lbl, 0, false) + extractExpr(tw, expr.Type, lbl, 1, false) case *ast.CallExpr: if expr == nil { return } kind = dbscheme.CallOrConversionExpr.Index() - extractExpr(tw, expr.Fun, lbl, 0) + extractExpr(tw, expr.Fun, lbl, 0, false) extractExprs(tw, expr.Args, lbl, 1, 1) if expr.Ellipsis.IsValid() { dbscheme.HasEllipsisTable.Emit(tw, lbl) @@ -1026,14 +1026,14 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { return } kind = dbscheme.StarExpr.Index() - extractExpr(tw, expr.X, lbl, 0) + extractExpr(tw, expr.X, lbl, 0, false) case *ast.KeyValueExpr: if expr == nil { return } kind = dbscheme.KeyValueExpr.Index() - extractExpr(tw, expr.Key, lbl, 0) - extractExpr(tw, expr.Value, lbl, 1) + extractExpr(tw, expr.Key, lbl, 0, false) + extractExpr(tw, expr.Value, lbl, 1, false) case *ast.UnaryExpr: if expr == nil { return @@ -1047,7 +1047,7 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { } kind = tp.Index() } - extractExpr(tw, expr.X, lbl, 0) + extractExpr(tw, expr.X, lbl, 0, false) case *ast.BinaryExpr: if expr == nil { return @@ -1062,16 +1062,17 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { log.Fatalf("unsupported binary operator %s", expr.Op) } kind = tp.Index() - extractExpr(tw, expr.X, lbl, 0) - extractExpr(tw, expr.Y, lbl, 1) + skipLeft := skipExtractingValueForLeftOperand(tw, expr) + extractExpr(tw, expr.X, lbl, 0, skipLeft) + extractExpr(tw, expr.Y, lbl, 1, false) } case *ast.ArrayType: if expr == nil { return } kind = dbscheme.ArrayTypeExpr.Index() - extractExpr(tw, expr.Len, lbl, 0) - extractExpr(tw, expr.Elt, lbl, 1) + extractExpr(tw, expr.Len, lbl, 0, false) + extractExpr(tw, expr.Elt, lbl, 1, false) case *ast.StructType: if expr == nil { return @@ -1100,8 +1101,8 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { return } kind = dbscheme.MapTypeExpr.Index() - extractExpr(tw, expr.Key, lbl, 0) - extractExpr(tw, expr.Value, lbl, 1) + extractExpr(tw, expr.Key, lbl, 0, false) + extractExpr(tw, expr.Value, lbl, 1, false) case *ast.ChanType: if expr == nil { return @@ -1111,13 +1112,15 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { log.Fatalf("unsupported channel direction %v", expr.Dir) } kind = tp.Index() - extractExpr(tw, expr.Value, lbl, 0) + extractExpr(tw, expr.Value, lbl, 0, false) default: log.Fatalf("unknown expression of type %T", expr) } dbscheme.ExprsTable.Emit(tw, lbl, kind, parent, idx) extractNodeLocation(tw, expr, lbl) - extractValueOf(tw, expr, lbl) + if !skipExtractingValue { + extractValueOf(tw, expr, lbl) + } } // extractExprs extracts AST information for a list of expressions, which are children of @@ -1128,7 +1131,7 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { func extractExprs(tw *trap.Writer, exprs []ast.Expr, parent trap.Label, idx int, dir int) { if exprs != nil { for _, expr := range exprs { - extractExpr(tw, expr, parent, idx) + extractExpr(tw, expr, parent, idx, false) idx += dir } } @@ -1194,11 +1197,11 @@ func extractFields(tw *trap.Writer, fields *ast.FieldList, parent trap.Label, id extractNodeLocation(tw, field, lbl) if field.Names != nil { for i, name := range field.Names { - extractExpr(tw, name, lbl, i+1) + extractExpr(tw, name, lbl, i+1, false) } } - extractExpr(tw, field.Type, lbl, 0) - extractExpr(tw, field.Tag, lbl, -1) + extractExpr(tw, field.Type, lbl, 0, false) + extractExpr(tw, field.Tag, lbl, -1, false) extractDoc(tw, field.Doc, lbl) idx += dir } @@ -1229,21 +1232,21 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { return } kind = dbscheme.LabeledStmtType.Index() - extractExpr(tw, stmt.Label, lbl, 0) + extractExpr(tw, stmt.Label, lbl, 0, false) extractStmt(tw, stmt.Stmt, lbl, 1) case *ast.ExprStmt: if stmt == nil { return } kind = dbscheme.ExprStmtType.Index() - extractExpr(tw, stmt.X, lbl, 0) + extractExpr(tw, stmt.X, lbl, 0, false) case *ast.SendStmt: if stmt == nil { return } kind = dbscheme.SendStmtType.Index() - extractExpr(tw, stmt.Chan, lbl, 0) - extractExpr(tw, stmt.Value, lbl, 1) + extractExpr(tw, stmt.Chan, lbl, 0, false) + extractExpr(tw, stmt.Value, lbl, 1, false) case *ast.IncDecStmt: if stmt == nil { return @@ -1255,7 +1258,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { } else { log.Fatalf("unsupported increment/decrement operator %v", stmt.Tok) } - extractExpr(tw, stmt.X, lbl, 0) + extractExpr(tw, stmt.X, lbl, 0, false) case *ast.AssignStmt: if stmt == nil { return @@ -1272,13 +1275,13 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { return } kind = dbscheme.GoStmtType.Index() - extractExpr(tw, stmt.Call, lbl, 0) + extractExpr(tw, stmt.Call, lbl, 0, false) case *ast.DeferStmt: if stmt == nil { return } kind = dbscheme.DeferStmtType.Index() - extractExpr(tw, stmt.Call, lbl, 0) + extractExpr(tw, stmt.Call, lbl, 0, false) case *ast.ReturnStmt: kind = dbscheme.ReturnStmtType.Index() extractExprs(tw, stmt.Results, lbl, 0, 1) @@ -1298,7 +1301,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { default: log.Fatalf("unsupported branch statement type %v", stmt.Tok) } - extractExpr(tw, stmt.Label, lbl, 0) + extractExpr(tw, stmt.Label, lbl, 0, false) case *ast.BlockStmt: if stmt == nil { return @@ -1312,7 +1315,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { } kind = dbscheme.IfStmtType.Index() extractStmt(tw, stmt.Init, lbl, 0) - extractExpr(tw, stmt.Cond, lbl, 1) + extractExpr(tw, stmt.Cond, lbl, 1, false) extractStmt(tw, stmt.Body, lbl, 2) extractStmt(tw, stmt.Else, lbl, 3) emitScopeNodeInfo(tw, stmt, lbl) @@ -1330,7 +1333,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { } kind = dbscheme.ExprSwitchStmtType.Index() extractStmt(tw, stmt.Init, lbl, 0) - extractExpr(tw, stmt.Tag, lbl, 1) + extractExpr(tw, stmt.Tag, lbl, 1, false) extractStmt(tw, stmt.Body, lbl, 2) emitScopeNodeInfo(tw, stmt, lbl) case *ast.TypeSwitchStmt: @@ -1359,7 +1362,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { } kind = dbscheme.ForStmtType.Index() extractStmt(tw, stmt.Init, lbl, 0) - extractExpr(tw, stmt.Cond, lbl, 1) + extractExpr(tw, stmt.Cond, lbl, 1, false) extractStmt(tw, stmt.Post, lbl, 2) extractStmt(tw, stmt.Body, lbl, 3) emitScopeNodeInfo(tw, stmt, lbl) @@ -1368,9 +1371,9 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) { return } kind = dbscheme.RangeStmtType.Index() - extractExpr(tw, stmt.Key, lbl, 0) - extractExpr(tw, stmt.Value, lbl, 1) - extractExpr(tw, stmt.X, lbl, 2) + extractExpr(tw, stmt.Key, lbl, 0, false) + extractExpr(tw, stmt.Value, lbl, 1, false) + extractExpr(tw, stmt.X, lbl, 2, false) extractStmt(tw, stmt.Body, lbl, 3) emitScopeNodeInfo(tw, stmt, lbl) default: @@ -1428,8 +1431,8 @@ func extractDecl(tw *trap.Writer, decl ast.Decl, parent trap.Label, idx int) { } kind = dbscheme.FuncDeclType.Index() extractFields(tw, decl.Recv, lbl, -1, -1) - extractExpr(tw, decl.Name, lbl, 0) - extractExpr(tw, decl.Type, lbl, 1) + extractExpr(tw, decl.Name, lbl, 0, false) + extractExpr(tw, decl.Type, lbl, 1, false) extractStmt(tw, decl.Body, lbl, 2) extractDoc(tw, decl.Doc, lbl) extractTypeParamDecls(tw, decl.Type.TypeParams, lbl) @@ -1455,8 +1458,8 @@ func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) { return } kind = dbscheme.ImportSpecType.Index() - extractExpr(tw, spec.Name, lbl, 0) - extractExpr(tw, spec.Path, lbl, 1) + extractExpr(tw, spec.Name, lbl, 0, false) + extractExpr(tw, spec.Path, lbl, 1, false) extractDoc(tw, spec.Doc, lbl) case *ast.ValueSpec: if spec == nil { @@ -1464,9 +1467,9 @@ func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) { } kind = dbscheme.ValueSpecType.Index() for i, name := range spec.Names { - extractExpr(tw, name, lbl, -(1 + i)) + extractExpr(tw, name, lbl, -(1 + i), false) } - extractExpr(tw, spec.Type, lbl, 0) + extractExpr(tw, spec.Type, lbl, 0, false) extractExprs(tw, spec.Values, lbl, 1, 1) extractDoc(tw, spec.Doc, lbl) case *ast.TypeSpec: @@ -1478,9 +1481,9 @@ func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) { } else { kind = dbscheme.TypeDefSpecType.Index() } - extractExpr(tw, spec.Name, lbl, 0) + extractExpr(tw, spec.Name, lbl, 0, false) extractTypeParamDecls(tw, spec.TypeParams, lbl) - extractExpr(tw, spec.Type, lbl, 1) + extractExpr(tw, spec.Type, lbl, 1, false) extractDoc(tw, spec.Doc, lbl) } dbscheme.SpecsTable.Emit(tw, lbl, kind, parent, idx) @@ -1909,7 +1912,7 @@ func flattenBinaryExprTree(tw *trap.Writer, e ast.Expr, parent trap.Label, idx i idx = flattenBinaryExprTree(tw, binaryexpr.X, parent, idx) idx = flattenBinaryExprTree(tw, binaryexpr.Y, parent, idx) } else { - extractExpr(tw, e, parent, idx) + extractExpr(tw, e, parent, idx, false) idx = idx + 1 } return idx @@ -1931,10 +1934,10 @@ func extractTypeParamDecls(tw *trap.Writer, fields *ast.FieldList, parent trap.L extractNodeLocation(tw, field, lbl) if field.Names != nil { for i, name := range field.Names { - extractExpr(tw, name, lbl, i+1) + extractExpr(tw, name, lbl, i+1, false) } } - extractExpr(tw, field.Type, lbl, 0) + extractExpr(tw, field.Type, lbl, 0, false) extractDoc(tw, field.Doc, lbl) idx += 1 } @@ -2023,3 +2026,24 @@ func setTypeParamParent(tp *types.TypeParam, newobj types.Object) { log.Fatalf("Parent of type parameter '%s %s' being set to a different value: '%s' vs '%s'", tp.String(), tp.Constraint().String(), obj, newobj) } } + +// skipExtractingValueForLeftOperand returns true if the left operand of `be` +// should not have its value extracted because it is an intermediate value in a +// string concatenation - specifically that the right operand is a string +// literal +func skipExtractingValueForLeftOperand(tw *trap.Writer, be *ast.BinaryExpr) bool { + // check `be` has string type + tpVal := tw.Package.TypesInfo.Types[be] + if tpVal.Value == nil || tpVal.Value.Kind() != constant.String { + return false + } + // check that the right operand of `be` is a basic literal + if _, isBasicLit := be.Y.(*ast.BasicLit); !isBasicLit { + return false + } + // check that the left operand of `be` is not a basic literal + if _, isBasicLit := be.X.(*ast.BasicLit); isBasicLit { + return false + } + return true +} From 33c17313b413e9f742f825e5bc4c65609f0d1c0a Mon Sep 17 00:00:00 2001 From: Owen Mansel-Chan Date: Tue, 12 Mar 2024 11:59:10 +0000 Subject: [PATCH 2/2] Add test for not extracting values for intermediate string concatenations --- .../no-intermediate-strings/tst.expected | 11 +++++++++++ .../no-intermediate-strings/tst.go | 5 +++++ .../no-intermediate-strings/tst.ql | 17 +++++++++++++++++ 3 files changed, 33 insertions(+) create mode 100644 go/ql/test/extractor-tests/no-intermediate-strings/tst.expected create mode 100644 go/ql/test/extractor-tests/no-intermediate-strings/tst.go create mode 100644 go/ql/test/extractor-tests/no-intermediate-strings/tst.ql diff --git a/go/ql/test/extractor-tests/no-intermediate-strings/tst.expected b/go/ql/test/extractor-tests/no-intermediate-strings/tst.expected new file mode 100644 index 000000000000..9746d74fb904 --- /dev/null +++ b/go/ql/test/extractor-tests/no-intermediate-strings/tst.expected @@ -0,0 +1,11 @@ +| tst.go:4:6:4:8 | "a" | a | +| tst.go:4:6:4:14 | ...+... | | +| tst.go:4:6:4:20 | ...+... | | +| tst.go:4:6:4:26 | ...+... | | +| tst.go:4:6:4:32 | ...+... | | +| tst.go:4:6:4:38 | ...+... | abcdef | +| tst.go:4:12:4:14 | "b" | b | +| tst.go:4:18:4:20 | "c" | c | +| tst.go:4:24:4:26 | "d" | d | +| tst.go:4:30:4:32 | "e" | e | +| tst.go:4:36:4:38 | "f" | f | diff --git a/go/ql/test/extractor-tests/no-intermediate-strings/tst.go b/go/ql/test/extractor-tests/no-intermediate-strings/tst.go new file mode 100644 index 000000000000..c79c97a8e88d --- /dev/null +++ b/go/ql/test/extractor-tests/no-intermediate-strings/tst.go @@ -0,0 +1,5 @@ +package main + +func main() { + _ = "a" + "b" + "c" + "d" + "e" + "f" +} diff --git a/go/ql/test/extractor-tests/no-intermediate-strings/tst.ql b/go/ql/test/extractor-tests/no-intermediate-strings/tst.ql new file mode 100644 index 000000000000..6367ef51e707 --- /dev/null +++ b/go/ql/test/extractor-tests/no-intermediate-strings/tst.ql @@ -0,0 +1,17 @@ +import go + +string checkStringValue(Expr e) { + result = e.getStringValue() + or + not exists(e.getStringValue()) and result = "" +} + +from Expr e +where e.getType() instanceof StringType +// We should get string values for `"a"`, `"b"`, `"c"` and `"a" + "b" + "c" +// but not `"a" + "b"`. In the extractor we avoid storing the value of +// intermediate strings in string concatenations because in pathological cases +// this could lead to a quadratic blowup in the size of string values stored, +// which then causes performance problems when we iterate through all string +// values. +select e, checkStringValue(e)