From 1e3f103eb335120fffa6bc026a03607b1b9a6cb2 Mon Sep 17 00:00:00 2001 From: Chris Smowton Date: Sat, 31 Aug 2024 18:25:31 +0100 Subject: [PATCH] Only extract transparent-alias versions of types when necessary --- go/extractor/extractor.go | 90 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/go/extractor/extractor.go b/go/extractor/extractor.go index 6996260ec0984..a2aced3e9fcbd 100644 --- a/go/extractor/extractor.go +++ b/go/extractor/extractor.go @@ -1513,17 +1513,96 @@ func extractType(tw *trap.Writer, tp types.Type) trap.Label { return extractTypeWithFlags(tw, tp, false) } +func containsAliasTypes(tp types.Type) bool { + switch tp := tp.(type) { + case *types.Basic: + return false + case *types.Array: + return containsAliasTypes(tp.Elem()) + case *types.Slice: + return containsAliasTypes(tp.Elem()) + case *types.Struct: + for i := 0; i < tp.NumFields(); i++ { + field := tp.Field(i) + if containsAliasTypes(field.Type()) { + return true + } + } + return false + case *types.Pointer: + return containsAliasTypes(tp.Elem()) + case *types.Interface: + for i := 0; i < tp.NumMethods(); i++ { + meth := tp.Method(i) + if containsAliasTypes(meth.Type()) { + return true + } + } + for i := 0; i < tp.NumEmbeddeds(); i++ { + if containsAliasTypes(tp.EmbeddedType(i)) { + return true + } + } + return false + case *types.Tuple: + for i := 0; i < tp.Len(); i++ { + if containsAliasTypes(tp.At(i).Type()) { + return true + } + } + return false + case *types.Signature: + params, results := tp.Params(), tp.Results() + if params != nil { + for i := 0; i < params.Len(); i++ { + param := params.At(i) + if containsAliasTypes(param.Type()) { + return true + } + } + } + if results != nil { + for i := 0; i < results.Len(); i++ { + result := results.At(i) + if containsAliasTypes(result.Type()) { + return true + } + } + } + return false + case *types.Map: + return containsAliasTypes(tp.Key()) || containsAliasTypes(tp.Elem()) + case *types.Chan: + return containsAliasTypes(tp.Elem()) + case *types.Named: + return false + case *types.TypeParam: + return false + case *types.Union: + for i := 0; i < tp.Len(); i++ { + term := tp.Term(i) + if containsAliasTypes(term.Type()) { + return true + } + } + return false + case *types.Alias: + return true + default: + log.Fatalf("unexpected type %T", tp) + } + return false +} + func extractTypeWithFlags(tw *trap.Writer, tp types.Type, transparentAliases bool) trap.Label { lbl, exists := getTypeLabelWithFlags(tw, tp, transparentAliases) if !exists { - if !transparentAliases { + if !transparentAliases && containsAliasTypes(tp) { // Ensure the (deep) underlying type is also extracted, so that it is // possible to implement deepUnalias in QL. // For example, if we had type A = int and type B = string, we would need // to extract map[string]int so that deepUnalias(map[B]A) has a real member // of @type to return. - // - // TODO: consider using a newtype to do this instead. extractTypeWithFlags(tw, tp, true) } var kind int @@ -1709,6 +1788,11 @@ func getTypeLabelWithFlags(tw *trap.Writer, tp types.Type, transparentAliases bo typeLabelKey := trap.TypeLabelsKey{Type: tp, TransparentAliases: transparentAliases} lbl, exists := tw.Labeler.TypeLabels[typeLabelKey] if !exists { + if transparentAliases && !containsAliasTypes(tp) { + // No aliases involved, so the label is the same as the non-transparent version + // of the same type. + return getTypeLabelWithFlags(tw, tp, false) + } switch tp := tp.(type) { case *types.Basic: lbl = tw.Labeler.GlobalID(fmt.Sprintf("%d;basictype", tp.Kind()))