Skip to content

Commit

Permalink
feat: migrate enums to new extractor
Browse files Browse the repository at this point in the history
fixes #1903
  • Loading branch information
worstell committed Jul 11, 2024
1 parent 09bfe4b commit 2fbabbd
Show file tree
Hide file tree
Showing 26 changed files with 887 additions and 735 deletions.
3 changes: 0 additions & 3 deletions buildengine/testdata/alpha/types.ftl.go

This file was deleted.

26 changes: 0 additions & 26 deletions buildengine/testdata/other/types.ftl.go

This file was deleted.

4 changes: 2 additions & 2 deletions buildengine/testdata/type_registry_main.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func init() {
{{- range .SumTypes}}
reflection.SumType[{{.Discriminator}}](
{{- range .Variants}}
*new({{.Type}}),
*new({{.Name}}),
{{- end}}
),
{{- end}}
Expand Down
10 changes: 8 additions & 2 deletions go-runtime/compile/build-template/types.ftl.go.tmpl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{{- $moduleName := .Name -}}

// Code generated by FTL. DO NOT EDIT.
package {{.Name}}

Expand All @@ -8,14 +10,18 @@ import (
{{.Import}}
{{end}}
"github.com/TBD54566975/ftl/go-runtime/ftl/reflection"

{{ range typesImports . }}
"{{.}}"
{{- end}}
)

func init() {
reflection.Register(
{{- range .LocalSumTypes}}
reflection.SumType[{{.Discriminator}}](
reflection.SumType[{{ trimModuleQualifier $moduleName .Discriminator }}](
{{- range .Variants}}
*new({{.Name}}),
*new({{ trimModuleQualifier $moduleName .Name }}),
{{- end}}
),
{{- end}}
Expand Down
179 changes: 118 additions & 61 deletions go-runtime/compile/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"unicode"

"github.com/alecthomas/types/optional"
sets "github.com/deckarep/golang-set/v2"
"golang.org/x/exp/maps"
"golang.org/x/mod/modfile"
Expand Down Expand Up @@ -70,6 +71,7 @@ type mainModuleContext struct {
type goSumType struct {
Discriminator string
Variants []goSumTypeVariant
fqName string
}

type goSumTypeVariant struct {
Expand Down Expand Up @@ -192,16 +194,19 @@ func Build(ctx context.Context, projectRootDir, moduleDir string, sch *schema.Sc
}
goVerbs = append(goVerbs, goverb)
}
sumTypes, goExternalTypes := getRegisteredTypes(result.Module, sch, result.NativeNames)
allSumTypes, goExternalTypes, err := getRegisteredTypes(result.Module, sch, result.NativeNames)
if err != nil {
return err
}
if err := internal.ScaffoldZip(buildTemplateFiles(), moduleDir, mainModuleContext{
GoVersion: goModVersion,
FTLVersion: ftlVersion,
Name: result.Module.Name,
SharedModulesPaths: sharedModulesPaths,
Verbs: goVerbs,
Replacements: replacements,
SumTypes: sumTypes,
LocalSumTypes: getLocalSumTypes(result.Module),
SumTypes: allSumTypes,
LocalSumTypes: getLocalSumTypes(result.Module, result.NativeNames),
ExternalGoTypes: goExternalTypes,
}, scaffolder.Exclude("^go.mod$"), scaffolder.Functions(funcs)); err != nil {
return err
Expand Down Expand Up @@ -430,12 +435,38 @@ var scaffoldFuncs = scaffolder.FuncMap{
imports.Add(strings.TrimPrefix(v.MustImport, "ftl/"))
}
for _, st := range ctx.SumTypes {
if i := strings.LastIndex(st.Discriminator, "."); i != -1 {
imports.Add(st.Discriminator[:i])
if i := strings.LastIndex(st.fqName, "."); i != -1 {
lessTypeName := strings.TrimSuffix(st.fqName, st.fqName[i:])
imports.Add(strings.TrimPrefix(lessTypeName, "ftl/"))
}
for _, v := range st.Variants {
if i := strings.LastIndex(v.Type, "."); i != -1 {
lessTypeName := strings.TrimSuffix(v.Type, v.Type[i:])
imports.Add(strings.TrimPrefix(lessTypeName, "ftl/"))
}
}
}
out := imports.ToSlice()
slices.Sort(out)
return out
},
"typesImports": func(ctx mainModuleContext) []string {
imports := sets.NewSet[string]()
for _, st := range ctx.LocalSumTypes {
if i := strings.LastIndex(st.fqName, "."); i != -1 {
lessTypeName := strings.TrimSuffix(st.fqName, st.fqName[i:])
// subpackage
if len(strings.Split(lessTypeName, "/")) > 2 {
imports.Add(lessTypeName)
}
}
for _, v := range st.Variants {
if i := strings.LastIndex(v.Type, "."); i != -1 {
imports.Add(v.Type[:i])
lessTypeName := strings.TrimSuffix(v.Type, v.Type[i:])
// subpackage
if len(strings.Split(lessTypeName, "/")) > 2 {
imports.Add(lessTypeName)
}
}
}
}
Expand Down Expand Up @@ -471,6 +502,12 @@ var scaffoldFuncs = scaffolder.FuncMap{
}
return out
},
"trimModuleQualifier": func(moduleName string, str string) string {
if strings.HasPrefix(str, moduleName+".") {
return strings.TrimPrefix(str, moduleName+".")
}
return str
},
}

func schemaType(t schema.Type) string {
Expand Down Expand Up @@ -643,50 +680,37 @@ func writeSchemaErrors(config moduleconfig.ModuleConfig, errors []*schema.Error)
return os.WriteFile(config.Abs().Errors, elBytes, 0600)
}

func getLocalSumTypes(module *schema.Module) []goSumType {
func getLocalSumTypes(module *schema.Module, nativeNames NativeNames) []goSumType {
sumTypes := make(map[string]goSumType)
for _, d := range module.Decls {
if e, ok := d.(*schema.Enum); ok && !e.IsValueEnum() {
variants := make([]goSumTypeVariant, 0, len(e.Variants))
for _, v := range e.Variants {
variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert
Name: v.Name,
})
}
sumTypes[e.Name] = goSumType{
Discriminator: e.Name,
Variants: variants,
}
e, ok := d.(*schema.Enum)
if !ok {
continue
}
if e.IsValueEnum() {
continue
}
if st, ok := getGoSumType(e, nativeNames).Get(); ok {
enumFqName := nativeNames[e]
sumTypes[enumFqName] = st
}
}
out := maps.Values(sumTypes)
slices.SortFunc(out, func(a, b goSumType) int {
return strings.Compare(a.Discriminator, b.Discriminator)
})
return out
return maps.Values(sumTypes)
}

func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType) {
func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType, error) {
sumTypes := make(map[string]goSumType)
goExternalTypes := make(map[string][]string)
for _, d := range module.Decls {
switch d := d.(type) {
case *schema.Enum:
if d.IsValueEnum() {
continue

}
variants := make([]goSumTypeVariant, 0, len(d.Variants))
for _, v := range d.Variants {
variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert
Name: v.Name,
Type: nativeNames[v],
SchemaType: v.Value.(*schema.TypeValue).Value,
})
}
stFqName := nativeNames[d]
sumTypes[stFqName] = goSumType{
Discriminator: nativeNames[d],
Variants: variants,
if st, ok := getGoSumType(d, nativeNames).Get(); ok {
enumFqName := nativeNames[d]
sumTypes[enumFqName] = st
}
case *schema.TypeAlias:
var fqName string
Expand All @@ -698,7 +722,10 @@ func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames N
if fqName == "" {
continue
}
im, typ := getGoExternalType(fqName)
im, typ, err := getGoExternalType(fqName)
if err != nil {
return nil, nil, err
}
if _, ok := goExternalTypes[im]; !ok {
goExternalTypes[im] = []string{}
}
Expand Down Expand Up @@ -735,35 +762,44 @@ func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames N
Types: types,
})
}
return out, externalTypes
return out, externalTypes, nil
}

func getGoExternalType(fqName string) (_import string, _type string) {
// package and directory names are the same (dir=bar, pkg=bar): "github.com/foo/bar.A"
// package and directory names differ (dir=bar, pkg=baz): "github.com/foo/bar.baz.A"
parts := strings.Split(fqName, "/")
lastPart := parts[len(parts)-1]
pkgParts := strings.Split(lastPart, ".")
if len(pkgParts) < 2 {
panic("unexpected Go qualified name format: " + fqName)
func getGoSumType(enum *schema.Enum, nativeNames NativeNames) optional.Option[goSumType] {
if enum.IsValueEnum() {
return optional.None[goSumType]()
}
variants := make([]goSumTypeVariant, 0, len(enum.Variants))
for _, v := range enum.Variants {
nativeName := nativeNames[v]
lastSlash := strings.LastIndex(nativeName, "/")
variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert
Name: nativeName[lastSlash+1:],
Type: nativeName,
SchemaType: v.Value.(*schema.TypeValue).Value,
})
}
stFqName := nativeNames[enum]
lastSlash := strings.LastIndex(stFqName, "/")
return optional.Some(goSumType{
Discriminator: stFqName[lastSlash+1:],
Variants: variants,
fqName: stFqName,
})
}

dirName := pkgParts[0]
typeName := pkgParts[len(pkgParts)-1]
pkg := pkgParts[len(pkgParts)-2]

// translate the fqName to a valid import
// e.g.:
// "github.com/foo/bar.A" -> "github.com/foo/bar"
// "github.com/foo/bar.baz.A" -> "baz github.com/foo/bar" (aliased because package and directory path differ)
dirPath := strings.TrimSuffix(fqName, lastPart) + dirName
im := fmt.Sprintf("%q", dirPath)
if len(pkgParts) > 2 {
// import has an alias with the real package name: `import baz "github.com/foo/bar"`
im = fmt.Sprintf("%s %s", pkg, im)
func getGoExternalType(fqName string) (_import string, _type string, err error) {
im, err := goImportFromQualifiedName(fqName)
if err != nil {
return "", "", err
}

return im, fmt.Sprintf("%s.%s", pkg, typeName)
pkg := im[strings.LastIndex(im, "/")+1:]
if i := strings.LastIndex(im, " "); i != -1 {
// import has an alias and this will be the package
pkg = im[:i]
}
typeName := fqName[strings.LastIndex(fqName, ".")+1:]
return im, fmt.Sprintf("%s.%s", pkg, typeName), nil
}

type externalEnum struct {
Expand Down Expand Up @@ -891,3 +927,24 @@ func goVerbFromQualifiedName(qualifiedName string) (goVerb, error) {
MustImport: pkgPath,
}, nil
}

// package and directory names are the same (dir=bar, pkg=bar): "github.com/foo/bar.A" => "github.com/foo/bar"
// package and directory names differ (dir=bar, pkg=baz): "github.com/foo/bar.baz.A" => "baz github.com/foo/bar"
func goImportFromQualifiedName(qualifiedName string) (string, error) {
lastDotIndex := strings.LastIndex(qualifiedName, ".")
if lastDotIndex == -1 {
return "", fmt.Errorf("invalid qualified type format %q", qualifiedName)
}

pkgPath := qualifiedName[:lastDotIndex]
pkgName := path.Base(pkgPath)

importAlias := ""
if lastDotIndex = strings.LastIndex(pkgName, "."); lastDotIndex != -1 {
pkgName = pkgName[lastDotIndex+1:]
pkgPath = pkgPath[:strings.LastIndex(pkgPath, ".")]
// package and path differ, so we need to alias the import
importAlias = pkgName + " "
}
return fmt.Sprintf("%s%q", importAlias, pkgPath), nil
}
Loading

0 comments on commit 2fbabbd

Please sign in to comment.