From 07303afa388f686bc12dcf4a5caa14ce291fc80e Mon Sep 17 00:00:00 2001 From: Elizabeth Worstell Date: Tue, 16 Jul 2024 11:15:49 -0700 Subject: [PATCH] feat: use external types in generated stubs if mapping present --- .../.ftl.tmpl/go/main/main.go.tmpl | 8 +- .../compile/build-template/types.ftl.go.tmpl | 6 +- go-runtime/compile/build.go | 194 +++++++++++++----- .../external_module.go.tmpl | 2 +- go-runtime/schema/common/common.go | 1 - go-runtime/schema/extract.go | 2 +- go-runtime/schema/typealias/analyzer.go | 26 +-- 7 files changed, 160 insertions(+), 79 deletions(-) diff --git a/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl b/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl index b6411bfc8b..32c27567d1 100644 --- a/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl +++ b/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl @@ -4,14 +4,14 @@ package main import ( "context" -{{- range .ExternalGoTypes }} +{{- range .ExternalTypes }} {{.Import}} {{- end}} "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect" "github.com/TBD54566975/ftl/common/plugin" -{{- if or .SumTypes .ExternalGoTypes }} +{{- if or .SumTypes .ExternalTypes }} "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" {{- end }} "github.com/TBD54566975/ftl/go-runtime/server" @@ -19,7 +19,7 @@ import ( "ftl/{{.}}" {{- end}} ) -{{- if or .SumTypes .ExternalGoTypes }} +{{- if or .SumTypes .ExternalTypes }} func init() { reflection.Register( @@ -30,7 +30,7 @@ func init() { {{- end}} ), {{- end}} -{{- range .ExternalGoTypes}} +{{- range .ExternalTypes}} {{- range .Types}} reflection.ExternalType(*new({{.}})), {{- end}} diff --git a/go-runtime/compile/build-template/types.ftl.go.tmpl b/go-runtime/compile/build-template/types.ftl.go.tmpl index c4fa7cd78b..2ce5783af3 100644 --- a/go-runtime/compile/build-template/types.ftl.go.tmpl +++ b/go-runtime/compile/build-template/types.ftl.go.tmpl @@ -3,10 +3,10 @@ // Code generated by FTL. DO NOT EDIT. package {{.Name}} -{{- if or .LocalSumTypes .ExternalGoTypes }} +{{- if or .LocalSumTypes .ExternalTypes }} import ( -{{- range .ExternalGoTypes }} +{{- range .ExternalTypes }} {{.Import}} {{end}} "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" @@ -25,7 +25,7 @@ func init() { {{- end}} ), {{- end}} -{{- range .ExternalGoTypes}} +{{- range .ExternalTypes}} {{- range .Types}} reflection.ExternalType(*new({{.}})), {{- end}} diff --git a/go-runtime/compile/build.go b/go-runtime/compile/build.go index 23b3b79f60..617226b718 100644 --- a/go-runtime/compile/build.go +++ b/go-runtime/compile/build.go @@ -65,7 +65,8 @@ type mainModuleContext struct { Replacements []*modfile.Replace SumTypes []goSumType LocalSumTypes []goSumType - ExternalGoTypes []goExternalType + ExternalTypes []goExternalType + LocalExternalTypes []goExternalType } type goSumType struct { @@ -194,10 +195,11 @@ func Build(ctx context.Context, projectRootDir, moduleDir string, sch *schema.Sc } goVerbs = append(goVerbs, goverb) } - allSumTypes, goExternalTypes, err := getRegisteredTypes(result.Module, sch, result.NativeNames) + localExternalTypes, err := getLocalExternalTypes(result.Module) if err != nil { return err } + allSumTypes, allExternalTypes, err := getRegisteredTypes(result.Module, sch, result.NativeNames) if err := internal.ScaffoldZip(buildTemplateFiles(), moduleDir, mainModuleContext{ GoVersion: goModVersion, FTLVersion: ftlVersion, @@ -207,7 +209,8 @@ func Build(ctx context.Context, projectRootDir, moduleDir string, sch *schema.Sc Replacements: replacements, SumTypes: allSumTypes, LocalSumTypes: getLocalSumTypes(result.Module, result.NativeNames), - ExternalGoTypes: goExternalTypes, + ExternalTypes: allExternalTypes, + LocalExternalTypes: localExternalTypes, }, scaffolder.Exclude("^go.mod$"), scaffolder.Functions(funcs)); err != nil { return err } @@ -387,6 +390,17 @@ var scaffoldFuncs = scaffolder.FuncMap{ if n.IsExported() { imports["github.com/TBD54566975/ftl/go-runtime/ftl"] = "" } + + case *schema.TypeAlias: + if n.IsExported() { + if im, _ := getGoExternalTypeForWidenedType(n); im != "" { + unquoted, err := strconv.Unquote(im) + if err != nil { + panic(err) + } + imports[unquoted] = "" + } + } default: } return next() @@ -508,6 +522,32 @@ var scaffoldFuncs = scaffolder.FuncMap{ } return str }, + "typeAliasType": func(m *schema.Module, t *schema.TypeAlias) string { + if _, goType := getGoExternalTypeForWidenedType(t); goType != "" { + return goType + } + return genType(m, t.Type) + }, +} + +func getGoExternalTypeForWidenedType(t *schema.TypeAlias) (_import string, _type string) { + var goType string + var im string + for _, md := range t.Metadata { + md, ok := md.(*schema.MetadataTypeMap) + if !ok { + continue + } + + if md.Runtime == "go" { + var err error + im, goType, err = getGoExternalType(md.NativeName) + if err != nil { + panic(err) + } + } + } + return im, goType } func schemaType(t schema.Type) string { @@ -698,20 +738,10 @@ func getLocalSumTypes(module *schema.Module, nativeNames NativeNames) []goSumTyp return maps.Values(sumTypes) } -func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType, error) { - sumTypes := make(map[string]goSumType) - goExternalTypes := make(map[string][]string) +func getLocalExternalTypes(module *schema.Module) ([]goExternalType, error) { + types := make(map[string][]string) for _, d := range module.Decls { switch d := d.(type) { - case *schema.Enum: - if d.IsValueEnum() { - continue - - } - if st, ok := getGoSumType(d, nativeNames).Get(); ok { - enumFqName := nativeNames[d] - sumTypes[enumFqName] = st - } case *schema.TypeAlias: var fqName string for _, m := range d.Metadata { @@ -724,45 +754,91 @@ func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames N } im, typ, err := getGoExternalType(fqName) if err != nil { - return nil, nil, err + return nil, err } - if _, ok := goExternalTypes[im]; !ok { - goExternalTypes[im] = []string{} + if _, ok := types[im]; !ok { + types[im] = []string{} } - goExternalTypes[im] = append(goExternalTypes[im], typ) + types[im] = append(types[im], typ) default: } } + var out []goExternalType + for im, types := range types { + out = append(out, goExternalType{ + Import: im, + Types: types, + }) + } + return out, nil +} +// getRegisteredTypesExternalToModule returns all sum types and external types that are not defined in the given module. +// These are the types that must be registered in the main module. +func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType, error) { + sumTypes := make(map[string]goSumType) + externalTypes := make(map[string]sets.Set[string]) // register sum types from other modules - for _, e := range getExternalTypeEnums(module, sch) { - variants := make([]goSumTypeVariant, 0, len(e.resolved.Variants)) - for _, v := range e.resolved.Variants { - variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert - Name: v.Name, - Type: e.ref.Module + "." + v.Name, - SchemaType: v.Value.(*schema.TypeValue).Value, - }) - } - stFqName := e.ref.Module + "." + e.ref.Name - sumTypes[e.ref.ToRefKey().String()] = goSumType{ - Discriminator: stFqName, - Variants: variants, + for _, decl := range getRegisteredTypesExternalToModule(module, sch) { + switch d := decl.resolved.(type) { + case *schema.Enum: + variants := make([]goSumTypeVariant, 0, len(d.Variants)) + for _, v := range d.Variants { + variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert + Name: decl.ref.Module + "." + v.Name, + Type: "ftl/" + decl.ref.Module + "." + v.Name, + SchemaType: v.Value.(*schema.TypeValue).Value, + }) + } + stFqName := decl.ref.Module + "." + decl.ref.Name + sumTypes[decl.ref.ToRefKey().String()] = goSumType{ + Discriminator: stFqName, + Variants: variants, + } + case *schema.TypeAlias: + for _, m := range d.Metadata { + if m, ok := m.(*schema.MetadataTypeMap); ok && m.Runtime == "go" { + im, typ, err := getGoExternalType(m.NativeName) + if err != nil { + return nil, nil, err + } + if _, ok := externalTypes[im]; !ok { + externalTypes[im] = sets.NewSet[string]() + } + externalTypes[im].Add(typ) + } + } + default: } } - out := maps.Values(sumTypes) - slices.SortFunc(out, func(a, b goSumType) int { + for _, d := range getLocalSumTypes(module, nativeNames) { + sumTypes[d.fqName] = d + } + stOut := maps.Values(sumTypes) + slices.SortFunc(stOut, func(a, b goSumType) int { return strings.Compare(a.Discriminator, b.Discriminator) }) - var externalTypes []goExternalType - for im, types := range goExternalTypes { - externalTypes = append(externalTypes, goExternalType{ + localExternalTypes, err := getLocalExternalTypes(module) + if err != nil { + return nil, nil, err + } + for _, et := range localExternalTypes { + if _, ok := externalTypes[et.Import]; !ok { + externalTypes[et.Import] = sets.NewSet[string]() + } + externalTypes[et.Import].Append(et.Types...) + } + + var etOut []goExternalType + for im, types := range externalTypes { + etOut = append(etOut, goExternalType{ Import: im, - Types: types, + Types: types.ToSlice(), }) } - return out, externalTypes, nil + + return stOut, etOut, nil } func getGoSumType(enum *schema.Enum, nativeNames NativeNames) optional.Option[goSumType] { @@ -811,42 +887,56 @@ func getGoExternalType(fqName string) (_import string, _type string, err error) return im, fmt.Sprintf("%s.%s", pkg, typeName), nil } -type externalEnum struct { +type externalDecl struct { ref *schema.Ref - resolved *schema.Enum + resolved schema.Decl } -// getExternalTypeEnums resolve all type enum references in the full schema -func getExternalTypeEnums(module *schema.Module, sch *schema.Schema) []externalEnum { +// getRegisteredTypesExternalToModule returns all sum types and external types that are not defined in the given module. +// These types must be registered in the main module. +func getRegisteredTypesExternalToModule(module *schema.Module, sch *schema.Schema) []externalDecl { combinedSch := schema.Schema{ Modules: append(sch.Modules, module), } - var externalTypeEnums []externalEnum + var externalTypes []externalDecl err := schema.Visit(&combinedSch, func(n schema.Node, next func() error) error { ref, ok := n.(*schema.Ref) if !ok { return next() } - if ref.Module != "" && ref.Module != module.Name { - return next() - } decl, ok := sch.Resolve(ref).Get() if !ok { return next() } - if e, ok := decl.(*schema.Enum); ok && !e.IsValueEnum() { - externalTypeEnums = append(externalTypeEnums, externalEnum{ + switch d := decl.(type) { + case *schema.Enum: + if ref.Module != "" && ref.Module != module.Name { + return next() + } + if d.IsValueEnum() { + return next() + } + externalTypes = append(externalTypes, externalDecl{ + ref: ref, + resolved: d, + }) + case *schema.TypeAlias: + if len(d.Metadata) == 0 { + return next() + } + externalTypes = append(externalTypes, externalDecl{ ref: ref, - resolved: e, + resolved: d, }) + default: } return next() }) if err != nil { - panic(fmt.Sprintf("failed to resolve external type enums schema: %v", err)) + panic(fmt.Sprintf("failed to resolve external types and sum types external to the module schema: %v", err)) } - return externalTypeEnums + return externalTypes } // ExtractModuleSchema statically parses Go FTL module source into a schema.Module diff --git a/go-runtime/compile/external-module-template/.ftl/go/modules/{{ .Module.Name }}/external_module.go.tmpl b/go-runtime/compile/external-module-template/.ftl/go/modules/{{ .Module.Name }}/external_module.go.tmpl index 657e8f20ab..8b22db023d 100644 --- a/go-runtime/compile/external-module-template/.ftl/go/modules/{{ .Module.Name }}/external_module.go.tmpl +++ b/go-runtime/compile/external-module-template/.ftl/go/modules/{{ .Module.Name }}/external_module.go.tmpl @@ -45,7 +45,7 @@ func ({{.Name|title}}) {{$enumInterfaceFuncName}}() {} {{- end}} {{- else if is "TypeAlias" .}} //ftl:typealias -type {{.Name|title}} {{type $.Module .Type}} +type {{.Name|title}} {{typeAliasType $.Module .}} {{- else if is "Data" .}} type {{.Name|title}} {{- if .TypeParameters}}[ diff --git a/go-runtime/schema/common/common.go b/go-runtime/schema/common/common.go index 93fe1ca546..14f1569934 100644 --- a/go-runtime/schema/common/common.go +++ b/go-runtime/schema/common/common.go @@ -574,7 +574,6 @@ func FuncPathEquals(pass *analysis.Pass, callExpr *ast.CallExpr, path string) bo func ApplyMetadata[T schema.Decl](pass *analysis.Pass, obj types.Object, apply func(md *ExtractedMetadata)) bool { if md, ok := GetFactForObject[*ExtractedMetadata](pass, obj).Get(); ok { if _, ok = md.Type.(T); !ok && md.Type != nil { - NoEndColumnErrorf(pass, obj.Pos(), "schema declaration contains conflicting directives") return false } apply(md) diff --git a/go-runtime/schema/extract.go b/go-runtime/schema/extract.go index c72c358275..767f1d3a53 100644 --- a/go-runtime/schema/extract.go +++ b/go-runtime/schema/extract.go @@ -355,7 +355,7 @@ func goQualifiedNameForWidenedType(obj types.Object, metadata []schema.Metadata) nativeName = m.NativeName } } - if nativeName == "" { + if len(metadata) > 0 && nativeName == "" { return "", fmt.Errorf("missing Go native name in typemapped alias for %q", common.GetNativeName(obj)) } diff --git a/go-runtime/schema/typealias/analyzer.go b/go-runtime/schema/typealias/analyzer.go index 9ccd0697b5..409fe3b680 100644 --- a/go-runtime/schema/typealias/analyzer.go +++ b/go-runtime/schema/typealias/analyzer.go @@ -27,13 +27,13 @@ func Extract(pass *analysis.Pass, node *ast.TypeSpec, obj types.Object) optional Name: strcase.ToUpperCamel(obj.Name()), Type: schType, } - if common.ApplyMetadata[*schema.TypeAlias](pass, obj, func(md *common.ExtractedMetadata) { + var hasGoTypeMapping bool + common.ApplyMetadata[*schema.TypeAlias](pass, obj, func(md *common.ExtractedMetadata) { alias.Comments = md.Comments alias.Export = md.IsExported alias.Metadata = md.Metadata if len(md.Metadata) > 0 { - hasGoTypeMap := false nativeName := qualifiedNameFromSelectorExpr(pass, node.Type) if nativeName == "" { return @@ -43,31 +43,23 @@ func Extract(pass *analysis.Pass, node *ast.TypeSpec, obj types.Object) optional if mt.Runtime != "go" { continue } + hasGoTypeMapping = true if nativeName != mt.NativeName { common.Errorf(pass, node, "declared type %s in typemap does not match native type %s", mt.NativeName, nativeName) return } - hasGoTypeMap = true } else { common.Errorf(pass, node, "unexpected directive on typealias %s", m) } } - - // if this alias contains any type mappings, implicitly add a Go type mapping if not already present - if !hasGoTypeMap { - alias.Metadata = append(alias.Metadata, &schema.MetadataTypeMap{ - Pos: common.GoPosToSchemaPos(pass.Fset, obj.Pos()), - Runtime: "go", - NativeName: nativeName, - }) - } - alias.Type = &schema.Any{} } - }) { - return optional.Some(alias) - } else if _, ok := alias.Type.(*schema.Any); ok && - !strings.HasPrefix(qualifiedNameFromSelectorExpr(pass, node.Type), "ftl") { + }) + + // if widening an external type, implicitly add a Go type mapping if one does not exist + if _, ok := alias.Type.(*schema.Any); ok && + !strings.HasPrefix(qualifiedNameFromSelectorExpr(pass, node.Type), "ftl/") && + !hasGoTypeMapping { alias.Metadata = append(alias.Metadata, &schema.MetadataTypeMap{ Pos: common.GoPosToSchemaPos(pass.Fset, obj.Pos()), Runtime: "go",