Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use external types in generated stubs if mapping present #2083

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@ package main
import (
"context"

{{- range .ExternalGoTypes }}
{{- range .ExternalTypes }}

{{.Import}}
{{- end}}

"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect"
"github.com/TBD54566975/ftl/common/plugin"
{{- if or .SumTypes .ExternalGoTypes }}
{{- if or .SumTypes .ExternalTypes }}
"github.com/TBD54566975/ftl/go-runtime/ftl/reflection"
{{- end }}
"github.com/TBD54566975/ftl/go-runtime/server"
{{ range mainImports . }}
"ftl/{{.}}"
{{- end}}
)
{{- if or .SumTypes .ExternalGoTypes }}
{{- if or .SumTypes .ExternalTypes }}

func init() {
reflection.Register(
Expand All @@ -30,7 +30,7 @@ func init() {
{{- end}}
),
{{- end}}
{{- range .ExternalGoTypes}}
{{- range .ExternalTypes}}
{{- range .Types}}
reflection.ExternalType(*new({{.}})),
{{- end}}
Expand Down
6 changes: 3 additions & 3 deletions go-runtime/compile/build-template/types.ftl.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
// Code generated by FTL. DO NOT EDIT.
package {{.Name}}

{{- if or .LocalSumTypes .ExternalGoTypes }}
{{- if or .LocalSumTypes .ExternalTypes }}

import (
{{- range .ExternalGoTypes }}
{{- range .ExternalTypes }}
{{.Import}}
{{end}}
"github.com/TBD54566975/ftl/go-runtime/ftl/reflection"
Expand All @@ -25,7 +25,7 @@ func init() {
{{- end}}
),
{{- end}}
{{- range .ExternalGoTypes}}
{{- range .ExternalTypes}}
{{- range .Types}}
reflection.ExternalType(*new({{.}})),
{{- end}}
Expand Down
194 changes: 142 additions & 52 deletions go-runtime/compile/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ type mainModuleContext struct {
Replacements []*modfile.Replace
SumTypes []goSumType
LocalSumTypes []goSumType
ExternalGoTypes []goExternalType
ExternalTypes []goExternalType
LocalExternalTypes []goExternalType
}

type goSumType struct {
Expand Down Expand Up @@ -194,10 +195,11 @@ func Build(ctx context.Context, projectRootDir, moduleDir string, sch *schema.Sc
}
goVerbs = append(goVerbs, goverb)
}
allSumTypes, goExternalTypes, err := getRegisteredTypes(result.Module, sch, result.NativeNames)
localExternalTypes, err := getLocalExternalTypes(result.Module)
if err != nil {
return err
}
allSumTypes, allExternalTypes, err := getRegisteredTypes(result.Module, sch, result.NativeNames)
if err := internal.ScaffoldZip(buildTemplateFiles(), moduleDir, mainModuleContext{
GoVersion: goModVersion,
FTLVersion: ftlVersion,
Expand All @@ -207,7 +209,8 @@ func Build(ctx context.Context, projectRootDir, moduleDir string, sch *schema.Sc
Replacements: replacements,
SumTypes: allSumTypes,
LocalSumTypes: getLocalSumTypes(result.Module, result.NativeNames),
ExternalGoTypes: goExternalTypes,
ExternalTypes: allExternalTypes,
LocalExternalTypes: localExternalTypes,
}, scaffolder.Exclude("^go.mod$"), scaffolder.Functions(funcs)); err != nil {
return err
}
Expand Down Expand Up @@ -387,6 +390,17 @@ var scaffoldFuncs = scaffolder.FuncMap{
if n.IsExported() {
imports["github.com/TBD54566975/ftl/go-runtime/ftl"] = ""
}

case *schema.TypeAlias:
if n.IsExported() {
if im, _ := getGoExternalTypeForWidenedType(n); im != "" {
unquoted, err := strconv.Unquote(im)
if err != nil {
panic(err)
}
imports[unquoted] = ""
}
}
default:
}
return next()
Expand Down Expand Up @@ -508,6 +522,32 @@ var scaffoldFuncs = scaffolder.FuncMap{
}
return str
},
"typeAliasType": func(m *schema.Module, t *schema.TypeAlias) string {
if _, goType := getGoExternalTypeForWidenedType(t); goType != "" {
return goType
}
return genType(m, t.Type)
},
}

func getGoExternalTypeForWidenedType(t *schema.TypeAlias) (_import string, _type string) {
var goType string
var im string
for _, md := range t.Metadata {
md, ok := md.(*schema.MetadataTypeMap)
if !ok {
continue
}

if md.Runtime == "go" {
var err error
im, goType, err = getGoExternalType(md.NativeName)
if err != nil {
panic(err)
}
}
}
return im, goType
}

func schemaType(t schema.Type) string {
Expand Down Expand Up @@ -698,20 +738,10 @@ func getLocalSumTypes(module *schema.Module, nativeNames NativeNames) []goSumTyp
return maps.Values(sumTypes)
}

func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType, error) {
sumTypes := make(map[string]goSumType)
goExternalTypes := make(map[string][]string)
func getLocalExternalTypes(module *schema.Module) ([]goExternalType, error) {
types := make(map[string][]string)
for _, d := range module.Decls {
switch d := d.(type) {
case *schema.Enum:
if d.IsValueEnum() {
continue

}
if st, ok := getGoSumType(d, nativeNames).Get(); ok {
enumFqName := nativeNames[d]
sumTypes[enumFqName] = st
}
case *schema.TypeAlias:
var fqName string
for _, m := range d.Metadata {
Expand All @@ -724,45 +754,91 @@ func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames N
}
im, typ, err := getGoExternalType(fqName)
if err != nil {
return nil, nil, err
return nil, err
}
if _, ok := goExternalTypes[im]; !ok {
goExternalTypes[im] = []string{}
if _, ok := types[im]; !ok {
types[im] = []string{}
}
goExternalTypes[im] = append(goExternalTypes[im], typ)
types[im] = append(types[im], typ)
default:
}
}
var out []goExternalType
for im, types := range types {
out = append(out, goExternalType{
Import: im,
Types: types,
})
}
return out, nil
}

// getRegisteredTypesExternalToModule returns all sum types and external types that are not defined in the given module.
// These are the types that must be registered in the main module.
func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType, error) {
sumTypes := make(map[string]goSumType)
externalTypes := make(map[string]sets.Set[string])
// register sum types from other modules
for _, e := range getExternalTypeEnums(module, sch) {
variants := make([]goSumTypeVariant, 0, len(e.resolved.Variants))
for _, v := range e.resolved.Variants {
variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert
Name: v.Name,
Type: e.ref.Module + "." + v.Name,
SchemaType: v.Value.(*schema.TypeValue).Value,
})
}
stFqName := e.ref.Module + "." + e.ref.Name
sumTypes[e.ref.ToRefKey().String()] = goSumType{
Discriminator: stFqName,
Variants: variants,
for _, decl := range getRegisteredTypesExternalToModule(module, sch) {
switch d := decl.resolved.(type) {
case *schema.Enum:
variants := make([]goSumTypeVariant, 0, len(d.Variants))
for _, v := range d.Variants {
variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert
Name: decl.ref.Module + "." + v.Name,
Type: "ftl/" + decl.ref.Module + "." + v.Name,
SchemaType: v.Value.(*schema.TypeValue).Value,
})
}
stFqName := decl.ref.Module + "." + decl.ref.Name
sumTypes[decl.ref.ToRefKey().String()] = goSumType{
Discriminator: stFqName,
Variants: variants,
}
case *schema.TypeAlias:
for _, m := range d.Metadata {
if m, ok := m.(*schema.MetadataTypeMap); ok && m.Runtime == "go" {
im, typ, err := getGoExternalType(m.NativeName)
if err != nil {
return nil, nil, err
}
if _, ok := externalTypes[im]; !ok {
externalTypes[im] = sets.NewSet[string]()
}
externalTypes[im].Add(typ)
}
}
default:
}
}
out := maps.Values(sumTypes)
slices.SortFunc(out, func(a, b goSumType) int {
for _, d := range getLocalSumTypes(module, nativeNames) {
sumTypes[d.fqName] = d
}
stOut := maps.Values(sumTypes)
slices.SortFunc(stOut, func(a, b goSumType) int {
return strings.Compare(a.Discriminator, b.Discriminator)
})

var externalTypes []goExternalType
for im, types := range goExternalTypes {
externalTypes = append(externalTypes, goExternalType{
localExternalTypes, err := getLocalExternalTypes(module)
if err != nil {
return nil, nil, err
}
for _, et := range localExternalTypes {
if _, ok := externalTypes[et.Import]; !ok {
externalTypes[et.Import] = sets.NewSet[string]()
}
externalTypes[et.Import].Append(et.Types...)
}

var etOut []goExternalType
for im, types := range externalTypes {
etOut = append(etOut, goExternalType{
Import: im,
Types: types,
Types: types.ToSlice(),
})
}
return out, externalTypes, nil

return stOut, etOut, nil
}

func getGoSumType(enum *schema.Enum, nativeNames NativeNames) optional.Option[goSumType] {
Expand Down Expand Up @@ -811,42 +887,56 @@ func getGoExternalType(fqName string) (_import string, _type string, err error)
return im, fmt.Sprintf("%s.%s", pkg, typeName), nil
}

type externalEnum struct {
type externalDecl struct {
ref *schema.Ref
resolved *schema.Enum
resolved schema.Decl
}

// getExternalTypeEnums resolve all type enum references in the full schema
func getExternalTypeEnums(module *schema.Module, sch *schema.Schema) []externalEnum {
// getRegisteredTypesExternalToModule returns all sum types and external types that are not defined in the given module.
// These types must be registered in the main module.
func getRegisteredTypesExternalToModule(module *schema.Module, sch *schema.Schema) []externalDecl {
combinedSch := schema.Schema{
Modules: append(sch.Modules, module),
}
var externalTypeEnums []externalEnum
var externalTypes []externalDecl
err := schema.Visit(&combinedSch, func(n schema.Node, next func() error) error {
ref, ok := n.(*schema.Ref)
if !ok {
return next()
}
if ref.Module != "" && ref.Module != module.Name {
return next()
}

decl, ok := sch.Resolve(ref).Get()
if !ok {
return next()
}
if e, ok := decl.(*schema.Enum); ok && !e.IsValueEnum() {
externalTypeEnums = append(externalTypeEnums, externalEnum{
switch d := decl.(type) {
case *schema.Enum:
if ref.Module != "" && ref.Module != module.Name {
return next()
}
if d.IsValueEnum() {
return next()
}
externalTypes = append(externalTypes, externalDecl{
ref: ref,
resolved: d,
})
case *schema.TypeAlias:
if len(d.Metadata) == 0 {
return next()
}
externalTypes = append(externalTypes, externalDecl{
ref: ref,
resolved: e,
resolved: d,
})
default:
}
return next()
})
if err != nil {
panic(fmt.Sprintf("failed to resolve external type enums schema: %v", err))
panic(fmt.Sprintf("failed to resolve external types and sum types external to the module schema: %v", err))
}
return externalTypeEnums
return externalTypes
}

// ExtractModuleSchema statically parses Go FTL module source into a schema.Module
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func ({{.Name|title}}) {{$enumInterfaceFuncName}}() {}
{{- end}}
{{- else if is "TypeAlias" .}}
//ftl:typealias
type {{.Name|title}} {{type $.Module .Type}}
type {{.Name|title}} {{typeAliasType $.Module .}}
{{- else if is "Data" .}}
type {{.Name|title}}
{{- if .TypeParameters}}[
Expand Down
1 change: 0 additions & 1 deletion go-runtime/schema/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,6 @@ func FuncPathEquals(pass *analysis.Pass, callExpr *ast.CallExpr, path string) bo
func ApplyMetadata[T schema.Decl](pass *analysis.Pass, obj types.Object, apply func(md *ExtractedMetadata)) bool {
if md, ok := GetFactForObject[*ExtractedMetadata](pass, obj).Get(); ok {
if _, ok = md.Type.(T); !ok && md.Type != nil {
NoEndColumnErrorf(pass, obj.Pos(), "schema declaration contains conflicting directives")
return false
}
apply(md)
Expand Down
2 changes: 1 addition & 1 deletion go-runtime/schema/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func goQualifiedNameForWidenedType(obj types.Object, metadata []schema.Metadata)
nativeName = m.NativeName
}
}
if nativeName == "" {
if len(metadata) > 0 && nativeName == "" {
return "", fmt.Errorf("missing Go native name in typemapped alias for %q",
common.GetNativeName(obj))
}
Expand Down
Loading
Loading