diff --git a/buildengine/testdata/alpha/types.ftl.go b/buildengine/testdata/alpha/types.ftl.go deleted file mode 100644 index 75e46a44b2..0000000000 --- a/buildengine/testdata/alpha/types.ftl.go +++ /dev/null @@ -1,3 +0,0 @@ -// Code generated by FTL. DO NOT EDIT. -package alpha - diff --git a/buildengine/testdata/other/types.ftl.go b/buildengine/testdata/other/types.ftl.go deleted file mode 100644 index c1e1d89f4b..0000000000 --- a/buildengine/testdata/other/types.ftl.go +++ /dev/null @@ -1,26 +0,0 @@ -// Code generated by FTL. DO NOT EDIT. -package other - -import "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" - -func init() { - reflection.Register( - reflection.SumType[SecondTypeEnum]( - *new(A), - *new(B), - ), - reflection.SumType[TypeEnum]( - *new(MyBool), - *new(MyBytes), - *new(MyFloat), - *new(MyInt), - *new(MyTime), - *new(MyList), - *new(MyMap), - *new(MyString), - *new(MyStruct), - *new(MyOption), - *new(MyUnit), - ), - ) -} diff --git a/buildengine/testdata/type_registry_main.go b/buildengine/testdata/type_registry_main.go index 28722bfd7f..2a051f0718 100644 --- a/buildengine/testdata/type_registry_main.go +++ b/buildengine/testdata/type_registry_main.go @@ -23,12 +23,12 @@ func init() { *new(other.MyBytes), *new(other.MyFloat), *new(other.MyInt), - *new(other.MyTime), *new(other.MyList), *new(other.MyMap), + *new(other.MyOption), *new(other.MyString), *new(other.MyStruct), - *new(other.MyOption), + *new(other.MyTime), *new(other.MyUnit), ), ) diff --git a/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl b/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl index b28fb7a8fa..b6411bfc8b 100644 --- a/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl +++ b/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl @@ -26,7 +26,7 @@ func init() { {{- range .SumTypes}} reflection.SumType[{{.Discriminator}}]( {{- range .Variants}} - *new({{.Type}}), + *new({{.Name}}), {{- end}} ), {{- end}} diff --git a/go-runtime/compile/build-template/types.ftl.go.tmpl b/go-runtime/compile/build-template/types.ftl.go.tmpl index c474a4c967..c4fa7cd78b 100644 --- a/go-runtime/compile/build-template/types.ftl.go.tmpl +++ b/go-runtime/compile/build-template/types.ftl.go.tmpl @@ -1,3 +1,5 @@ +{{- $moduleName := .Name -}} + // Code generated by FTL. DO NOT EDIT. package {{.Name}} @@ -8,14 +10,18 @@ import ( {{.Import}} {{end}} "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" + +{{ range typesImports . }} + "{{.}}" +{{- end}} ) func init() { reflection.Register( {{- range .LocalSumTypes}} - reflection.SumType[{{.Discriminator}}]( + reflection.SumType[{{ trimModuleQualifier $moduleName .Discriminator }}]( {{- range .Variants}} - *new({{.Name}}), + *new({{ trimModuleQualifier $moduleName .Name }}), {{- end}} ), {{- end}} diff --git a/go-runtime/compile/build.go b/go-runtime/compile/build.go index c6de29bb50..37ba8ad9d9 100644 --- a/go-runtime/compile/build.go +++ b/go-runtime/compile/build.go @@ -13,6 +13,7 @@ import ( "strings" "unicode" + "github.com/alecthomas/types/optional" sets "github.com/deckarep/golang-set/v2" "golang.org/x/exp/maps" "golang.org/x/mod/modfile" @@ -70,6 +71,7 @@ type mainModuleContext struct { type goSumType struct { Discriminator string Variants []goSumTypeVariant + fqName string } type goSumTypeVariant struct { @@ -192,7 +194,10 @@ func Build(ctx context.Context, projectRootDir, moduleDir string, sch *schema.Sc } goVerbs = append(goVerbs, goverb) } - sumTypes, goExternalTypes := getRegisteredTypes(result.Module, sch, result.NativeNames) + allSumTypes, goExternalTypes, err := getRegisteredTypes(result.Module, sch, result.NativeNames) + if err != nil { + return err + } if err := internal.ScaffoldZip(buildTemplateFiles(), moduleDir, mainModuleContext{ GoVersion: goModVersion, FTLVersion: ftlVersion, @@ -200,8 +205,8 @@ func Build(ctx context.Context, projectRootDir, moduleDir string, sch *schema.Sc SharedModulesPaths: sharedModulesPaths, Verbs: goVerbs, Replacements: replacements, - SumTypes: sumTypes, - LocalSumTypes: getLocalSumTypes(result.Module), + SumTypes: allSumTypes, + LocalSumTypes: getLocalSumTypes(result.Module, result.NativeNames), ExternalGoTypes: goExternalTypes, }, scaffolder.Exclude("^go.mod$"), scaffolder.Functions(funcs)); err != nil { return err @@ -430,12 +435,38 @@ var scaffoldFuncs = scaffolder.FuncMap{ imports.Add(strings.TrimPrefix(v.MustImport, "ftl/")) } for _, st := range ctx.SumTypes { - if i := strings.LastIndex(st.Discriminator, "."); i != -1 { - imports.Add(st.Discriminator[:i]) + if i := strings.LastIndex(st.fqName, "."); i != -1 { + lessTypeName := strings.TrimSuffix(st.fqName, st.fqName[i:]) + imports.Add(strings.TrimPrefix(lessTypeName, "ftl/")) + } + for _, v := range st.Variants { + if i := strings.LastIndex(v.Type, "."); i != -1 { + lessTypeName := strings.TrimSuffix(v.Type, v.Type[i:]) + imports.Add(strings.TrimPrefix(lessTypeName, "ftl/")) + } + } + } + out := imports.ToSlice() + slices.Sort(out) + return out + }, + "typesImports": func(ctx mainModuleContext) []string { + imports := sets.NewSet[string]() + for _, st := range ctx.LocalSumTypes { + if i := strings.LastIndex(st.fqName, "."); i != -1 { + lessTypeName := strings.TrimSuffix(st.fqName, st.fqName[i:]) + // subpackage + if len(strings.Split(lessTypeName, "/")) > 2 { + imports.Add(lessTypeName) + } } for _, v := range st.Variants { if i := strings.LastIndex(v.Type, "."); i != -1 { - imports.Add(v.Type[:i]) + lessTypeName := strings.TrimSuffix(v.Type, v.Type[i:]) + // subpackage + if len(strings.Split(lessTypeName, "/")) > 2 { + imports.Add(lessTypeName) + } } } } @@ -471,6 +502,12 @@ var scaffoldFuncs = scaffolder.FuncMap{ } return out }, + "trimModuleQualifier": func(moduleName string, str string) string { + if strings.HasPrefix(str, moduleName+".") { + return strings.TrimPrefix(str, moduleName+".") + } + return str + }, } func schemaType(t schema.Type) string { @@ -643,30 +680,25 @@ func writeSchemaErrors(config moduleconfig.ModuleConfig, errors []*schema.Error) return os.WriteFile(config.Abs().Errors, elBytes, 0600) } -func getLocalSumTypes(module *schema.Module) []goSumType { +func getLocalSumTypes(module *schema.Module, nativeNames NativeNames) []goSumType { sumTypes := make(map[string]goSumType) for _, d := range module.Decls { - if e, ok := d.(*schema.Enum); ok && !e.IsValueEnum() { - variants := make([]goSumTypeVariant, 0, len(e.Variants)) - for _, v := range e.Variants { - variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert - Name: v.Name, - }) - } - sumTypes[e.Name] = goSumType{ - Discriminator: e.Name, - Variants: variants, - } + e, ok := d.(*schema.Enum) + if !ok { + continue + } + if e.IsValueEnum() { + continue + } + if st, ok := getGoSumType(e, nativeNames).Get(); ok { + enumFqName := nativeNames[e] + sumTypes[enumFqName] = st } } - out := maps.Values(sumTypes) - slices.SortFunc(out, func(a, b goSumType) int { - return strings.Compare(a.Discriminator, b.Discriminator) - }) - return out + return maps.Values(sumTypes) } -func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType) { +func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames NativeNames) ([]goSumType, []goExternalType, error) { sumTypes := make(map[string]goSumType) goExternalTypes := make(map[string][]string) for _, d := range module.Decls { @@ -674,19 +706,11 @@ func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames N case *schema.Enum: if d.IsValueEnum() { continue + } - variants := make([]goSumTypeVariant, 0, len(d.Variants)) - for _, v := range d.Variants { - variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert - Name: v.Name, - Type: nativeNames[v], - SchemaType: v.Value.(*schema.TypeValue).Value, - }) - } - stFqName := nativeNames[d] - sumTypes[stFqName] = goSumType{ - Discriminator: nativeNames[d], - Variants: variants, + if st, ok := getGoSumType(d, nativeNames).Get(); ok { + enumFqName := nativeNames[d] + sumTypes[enumFqName] = st } case *schema.TypeAlias: var fqName string @@ -698,7 +722,10 @@ func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames N if fqName == "" { continue } - im, typ := getGoExternalType(fqName) + im, typ, err := getGoExternalType(fqName) + if err != nil { + return nil, nil, err + } if _, ok := goExternalTypes[im]; !ok { goExternalTypes[im] = []string{} } @@ -735,35 +762,44 @@ func getRegisteredTypes(module *schema.Module, sch *schema.Schema, nativeNames N Types: types, }) } - return out, externalTypes + return out, externalTypes, nil } -func getGoExternalType(fqName string) (_import string, _type string) { - // package and directory names are the same (dir=bar, pkg=bar): "github.com/foo/bar.A" - // package and directory names differ (dir=bar, pkg=baz): "github.com/foo/bar.baz.A" - parts := strings.Split(fqName, "/") - lastPart := parts[len(parts)-1] - pkgParts := strings.Split(lastPart, ".") - if len(pkgParts) < 2 { - panic("unexpected Go qualified name format: " + fqName) +func getGoSumType(enum *schema.Enum, nativeNames NativeNames) optional.Option[goSumType] { + if enum.IsValueEnum() { + return optional.None[goSumType]() + } + variants := make([]goSumTypeVariant, 0, len(enum.Variants)) + for _, v := range enum.Variants { + nativeName := nativeNames[v] + lastSlash := strings.LastIndex(nativeName, "/") + variants = append(variants, goSumTypeVariant{ //nolint:forcetypeassert + Name: nativeName[lastSlash+1:], + Type: nativeName, + SchemaType: v.Value.(*schema.TypeValue).Value, + }) } + stFqName := nativeNames[enum] + lastSlash := strings.LastIndex(stFqName, "/") + return optional.Some(goSumType{ + Discriminator: stFqName[lastSlash+1:], + Variants: variants, + fqName: stFqName, + }) +} - dirName := pkgParts[0] - typeName := pkgParts[len(pkgParts)-1] - pkg := pkgParts[len(pkgParts)-2] - - // translate the fqName to a valid import - // e.g.: - // "github.com/foo/bar.A" -> "github.com/foo/bar" - // "github.com/foo/bar.baz.A" -> "baz github.com/foo/bar" (aliased because package and directory path differ) - dirPath := strings.TrimSuffix(fqName, lastPart) + dirName - im := fmt.Sprintf("%q", dirPath) - if len(pkgParts) > 2 { - // import has an alias with the real package name: `import baz "github.com/foo/bar"` - im = fmt.Sprintf("%s %s", pkg, im) +func getGoExternalType(fqName string) (_import string, _type string, err error) { + im, err := goImportFromQualifiedName(fqName) + if err != nil { + return "", "", err } - - return im, fmt.Sprintf("%s.%s", pkg, typeName) + pkg := im[strings.LastIndex(im, "/")+1:] + if i := strings.LastIndex(im, " "); i != -1 { + // import has an alias and this will be the package + pkg = im[:i] + } + typeName := fqName[strings.LastIndex(fqName, ".")+1:] + return im, fmt.Sprintf("%s.%s", pkg, typeName), nil } type externalEnum struct { @@ -891,3 +927,24 @@ func goVerbFromQualifiedName(qualifiedName string) (goVerb, error) { MustImport: pkgPath, }, nil } + +// package and directory names are the same (dir=bar, pkg=bar): "github.com/foo/bar.A" => "github.com/foo/bar" +// package and directory names differ (dir=bar, pkg=baz): "github.com/foo/bar.baz.A" => "baz github.com/foo/bar" +func goImportFromQualifiedName(qualifiedName string) (string, error) { + lastDotIndex := strings.LastIndex(qualifiedName, ".") + if lastDotIndex == -1 { + return "", fmt.Errorf("invalid qualified type format %q", qualifiedName) + } + + pkgPath := qualifiedName[:lastDotIndex] + pkgName := path.Base(pkgPath) + + importAlias := "" + if lastDotIndex = strings.LastIndex(pkgName, "."); lastDotIndex != -1 { + pkgName = pkgName[lastDotIndex+1:] + pkgPath = pkgPath[:strings.LastIndex(pkgPath, ".")] + // package and path differ, so we need to alias the import + importAlias = pkgName + " " + } + return fmt.Sprintf("%s%q", importAlias, pkgPath), nil +} diff --git a/go-runtime/compile/schema.go b/go-runtime/compile/schema.go index f7f18e28a3..bc13d1ea7e 100644 --- a/go-runtime/compile/schema.go +++ b/go-runtime/compile/schema.go @@ -7,7 +7,6 @@ import ( "go/types" "path" "reflect" - "slices" "strconv" "strings" "unicode" @@ -35,8 +34,6 @@ var ( ftlConfigFuncPath = "github.com/TBD54566975/ftl/go-runtime/ftl.Config" ftlSecretFuncPath = "github.com/TBD54566975/ftl/go-runtime/ftl.Secret" //nolint:gosec ftlPostgresDBFuncPath = "github.com/TBD54566975/ftl/go-runtime/ftl.PostgresDatabase" - ftlUnitTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Unit" - ftlOptionTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Option" ftlTopicFuncPath = "github.com/TBD54566975/ftl/go-runtime/ftl.Topic" ftlSubscriptionFuncPath = "github.com/TBD54566975/ftl/go-runtime/ftl.Subscription" ftlTopicHandleTypeName = "TopicHandle" @@ -46,9 +43,6 @@ var ( // NativeNames is a map of top-level declarations to their native Go names. type NativeNames map[schema.Node]string -// enumInterfaces is a map of type enum names to the interface that variants must conform to. -type enumInterfaces map[string]*types.Interface - func noEndColumnErrorf(pos token.Pos, format string, args ...interface{}) *schema.Error { return tokenErrorf(pos, "", format, args...) } @@ -122,9 +116,6 @@ func legacyExtractModuleSchema(dir string, sch *schema.Schema, out *extract.Resu case *ast.CallExpr: visitCallExpr(pctx, node, stack) - case *ast.GenDecl: - visitGenDecl(pctx, node) - default: } return next() @@ -140,24 +131,14 @@ func legacyExtractModuleSchema(dir string, sch *schema.Schema, out *extract.Resu return nil } -// extractInitialDecls traverses the package's AST and extracts declarations needed up front (type aliases, enums and topics) -// -// This allows us to know if a type is a type alias or an enum regardless of ordering when visiting each ast node. -// - The decls get added to the pctx's module, nativeNames and enumInterfaces. -// - We only want to do a simple pass, so we do not resolve references to other types. This means the TypeAlias and Enum decls have Type = nil -// - This get's filled in with the next pass +// extractInitialDecls traverses the package's AST and extracts declarations needed up front (topics) // -// It also helps with topics because we need to know the stack when visiting a topic decl, but the subscription may occur first. +// We need to know the stack when visiting a topic decl, but the subscription may occur first. // In this case there is no way for the subscription to make the topic exported. func extractInitialDecls(pctx *parseContext) error { for _, file := range pctx.pkg.Syntax { err := goast.Visit(file, func(stack []ast.Node, next func() error) (err error) { switch node := stack[len(stack)-1].(type) { - case *ast.GenDecl: - if node.Tok == token.TYPE { - extractTypeDecl(pctx, node) - } - case *ast.CallExpr: _, fn := deref[*types.Func](pctx.pkg, node.Fun) if fn != nil && fn.FullName() == ftlTopicFuncPath { @@ -174,75 +155,6 @@ func extractInitialDecls(pctx *parseContext) error { return nil } -func extractTypeDecl(pctx *parseContext, node *ast.GenDecl) { - directives, parseErr := parseDirectives(node, fset, node.Doc) - if parseErr != nil { - // errors collected when visiting all nodes in the next pass - return - } - - for _, dir := range directives { - if len(node.Specs) != 1 { - // errors handled on next pass - continue - } - t, ok := node.Specs[0].(*ast.TypeSpec) - if !ok { - continue - } - - aType := pctx.pkg.Types.Scope().Lookup(t.Name.Name) - nativeName := aType.Pkg().Name() + "." + aType.Name() - - if ed, ok := dir.(*directiveEnum); ok { - typ := pctx.pkg.TypesInfo.TypeOf(t.Type) - switch underlying := typ.Underlying().(type) { - case *types.Basic: - enum := &schema.Enum{ - Pos: goPosToSchemaPos(node.Pos()), - Comments: parseComments(node.Doc), - Name: strcase.ToUpperCamel(t.Name.Name), - Type: nil, // nil until next pass, when we can visit the full type graph - Export: ed.IsExported(), - } - pctx.module.Decls = append(pctx.module.Decls, enum) - pctx.nativeNames[enum] = nativeName - case *types.Interface: - if underlying.NumMethods() == 0 { - pctx.errors.add(errorf(node, "enum discriminator %q must define at least one method", t.Name.Name)) - break - } - - hasExportedMethod := false - for i, n := 0, underlying.NumMethods(); i < n; i++ { - if underlying.Method(i).Exported() { - pctx.errors.add(noEndColumnErrorf(underlying.Method(i).Pos(), "enum discriminator %q cannot "+ - "contain exported methods", t.Name.Name)) - hasExportedMethod = true - } - } - if hasExportedMethod { - break - } - - enum := &schema.Enum{ - Pos: goPosToSchemaPos(node.Pos()), - Comments: parseComments(node.Doc), - Name: strcase.ToUpperCamel(t.Name.Name), - Export: ed.IsExported(), - } - if iTyp, ok := typ.(*types.Interface); ok { - pctx.nativeNames[enum] = nativeName - pctx.module.Decls = append(pctx.module.Decls, enum) - pctx.enumInterfaces[t.Name.Name] = iTyp - } else { - pctx.errors.add(errorf(node, "expected interface for type enum but got %q", typ)) - } - } - } - } -} - func extractStringLiteralArg(node *ast.CallExpr, argIndex int) (string, *schema.Error) { if argIndex >= len(node.Args) { return "", errorf(node, "expected string argument at index %d", argIndex) @@ -766,315 +678,6 @@ func goNodePosToSchemaPos(node ast.Node) (schema.Position, int) { return schema.Position{Filename: p.Filename, Line: p.Line, Column: p.Column, Offset: p.Offset}, fset.Position(node.End()).Column } -func visitGenDecl(pctx *parseContext, node *ast.GenDecl) { - switch node.Tok { - case token.TYPE: - directives, err := parseDirectives(node, fset, node.Doc) - if err != nil { - pctx.errors.add(err) - } - maybeVisitTypeEnumVariant(pctx, node, directives) - - if node.Doc == nil { - return - } - - for _, dir := range directives { - if _, ok := dir.(*directiveEnum); ok { - if len(node.Specs) != 1 { - pctx.errors.add(errorf(node, "error parsing ftl directive: expected "+ - "exactly one type declaration")) - return - } - if pctx.module.Name == "" { - pctx.module.Name = pctx.pkg.Name - } else if pctx.module.Name != pctx.pkg.Name && strings.TrimPrefix(pctx.pkg.Name, "ftl/") != pctx.module.Name { - pctx.errors.add(errorf(node, "ftl directive must be in the module package")) - return - } - if t, ok := node.Specs[0].(*ast.TypeSpec); ok { - isExported := false - if exportableDir, ok := dir.(exportable); ok { - isExported = exportableDir.IsExported() - } - // We have already collected enum and type alias declarations in extractTypeDecls - // On this second pass we can visit deeper and pull out the type information - typ := pctx.pkg.TypesInfo.TypeOf(t.Type) - if _, ok := dir.(*directiveEnum); ok { - enumOption, enumInterface := pctx.getEnumForTypeName(t.Name.Name) - enum, ok := enumOption.Get() - if !ok { - // This case can be reached if a type is both an enum and a typealias. - // Error is already reported in extractTypeDecls - return - } - switch typ.Underlying().(type) { - case *types.Basic: - if sType, ok := visitType(pctx, node.Pos(), typ, isExported).Get(); ok { - enum.Type = sType - } else { - pctx.errors.add(errorf(node, "unsupported type %q for value enum", - pctx.pkg.TypesInfo.TypeOf(t.Type).Underlying())) - } - case *types.Interface: - if !enumInterface.Ok() { - pctx.errors.add(errorf(node, "could not find interface for type enum")) - } - } - } else { - visitType(pctx, node.Pos(), pctx.pkg.TypesInfo.Defs[t.Name].Type(), isExported) - } - } - } - } - return - - case token.CONST: - var typ ast.Expr - for i, s := range node.Specs { - v, ok := s.(*ast.ValueSpec) - if !ok { - continue - } - // In an iota enum, only the first value has a type. - // Hydrate this to subsequent values so we can associate them with the enum. - if i == 0 && isIotaEnum(v) { - typ = v.Type - } else if v.Type == nil { - v.Type = typ - } - visitValueSpec(pctx, v) - } - return - - default: - return - } -} - -func maybeVisitTypeEnumVariant(pctx *parseContext, node *ast.GenDecl, directives []directive) { - if len(node.Specs) != 1 { - return - } - // `type NAME TYPE` e.g. type Scalar string - t, ok := node.Specs[0].(*ast.TypeSpec) - if !ok { - return - } - typ := pctx.pkg.TypesInfo.TypeOf(t.Type) - if _, ok := typ.Underlying().(*types.Interface); ok { - // Type enums should not count as variants of themselves - return - } - - enumVariant := &schema.EnumVariant{ - Pos: goPosToSchemaPos(node.Pos()), - Comments: parseComments(node.Doc), - Name: strcase.ToUpperCamel(t.Name.Name), - } - - matchedEnumNames := []string{} - - // iterate in a predictable way to make sure we are not flipflopping between builds of which type enum counts as first - allEnumNames := maps.Keys(pctx.enumInterfaces) - slices.Sort(allEnumNames) - for _, enumName := range allEnumNames { - interfaceNode := pctx.enumInterfaces[enumName] - - // If the type declared is an enum variant, then it must implement - // the interface of a type enum - named, ok := pctx.pkg.Types.Scope().Lookup(t.Name.Name).Type().(*types.Named) - if !ok { - continue - } - if !types.Implements(named, interfaceNode) { - continue - } - - enumOption, _ := pctx.getEnumForTypeName(enumName) - enum, ok := enumOption.Get() - if !ok { - pctx.errors.add(errorf(node, "could not find enum called %s", enumName)) - continue - } - - matchedEnumNames = append(matchedEnumNames, enumName) - if len(matchedEnumNames) > 1 { - continue - } - - if enum.VariantForName(enumVariant.Name).Ok() { - continue - } - - // If any directives on this node are exported, then the - // enum variant node is considered exported. Also, if the - // parent enum itself is exported, then all its variants - // should transitively also be exported. - isExported := enum.IsExported() - for _, dir := range directives { - if exportableDir, ok := dir.(exportable); ok { - isExported = exportableDir.IsExported() || isExported - } - } - vType, ok := visitTypeValue(pctx, named, t.Type, nil, isExported).Get() - if !ok { - pctx.errors.add(errorf(node, "unsupported type %q for type enum variant", named)) - continue - } - enumVariant.Value = vType - enum.Variants = append(enum.Variants, enumVariant) - pctx.nativeNames[enumVariant] = named.Obj().Pkg().Name() + "." + named.Obj().Name() - } - if len(matchedEnumNames) > 1 { - slices.Sort(matchedEnumNames) - pctx.errors.add(errorf(node, "type can not be a variant of more than 1 type enums (%s)", strings.Join(matchedEnumNames, ", "))) - } -} - -func visitTypeValue(pctx *parseContext, named *types.Named, tnode ast.Expr, index types.Type, isExported bool) optional.Option[*schema.TypeValue] { - switch typ := tnode.(type) { - // Selector expression e.g. ftl.Unit, ftl.Option, foo.Bar - case *ast.SelectorExpr: - var ident *ast.Ident - var ok bool - if ident, ok = typ.X.(*ast.Ident); !ok { - return optional.None[*schema.TypeValue]() - } - - for _, im := range maps.Values(pctx.pkg.Imports) { - if im.Name != ident.Name { - continue - } - switch im.ID + "." + typ.Sel.Name { - case "time.Time": - return optional.Some(&schema.TypeValue{ - Pos: goPosToSchemaPos(tnode.Pos()), - Value: &schema.Time{}, - }) - case ftlUnitTypePath: - return optional.Some(&schema.TypeValue{ - Pos: goPosToSchemaPos(tnode.Pos()), - Value: &schema.Unit{}, - }) - case ftlOptionTypePath: - if index == nil { - return optional.None[*schema.TypeValue]() - } - - if vt, ok := visitType(pctx, tnode.Pos(), index, isExported).Get(); ok { - return optional.Some(&schema.TypeValue{ - Pos: goPosToSchemaPos(tnode.Pos()), - Value: &schema.Optional{ - Pos: goPosToSchemaPos(tnode.Pos()), - Type: vt, - }, - }) - } - default: // Data ref - externalModuleName, ok := ftlModuleFromGoModule(im.ID).Get() - if !ok { - pctx.errors.add(errorf(tnode, "package %q is not in the ftl namespace", im.ID)) - return optional.None[*schema.TypeValue]() - } - return optional.Some(&schema.TypeValue{ - Pos: goPosToSchemaPos(tnode.Pos()), - Value: &schema.Ref{ - Pos: goPosToSchemaPos(tnode.Pos()), - Module: externalModuleName, - Name: typ.Sel.Name, - }, - }) - } - } - - case *ast.IndexExpr: // Generic type, e.g. ftl.Option[string] - if se, ok := typ.X.(*ast.SelectorExpr); ok { - return visitTypeValue(pctx, named, se, pctx.pkg.TypesInfo.TypeOf(typ.Index), isExported) - } - - default: - variantNode := pctx.pkg.TypesInfo.TypeOf(tnode) - if _, ok := variantNode.(*types.Struct); ok { - variantNode = named - } - if typ, ok := visitType(pctx, tnode.Pos(), variantNode, isExported).Get(); ok { - return optional.Some(&schema.TypeValue{Value: typ}) - } else { - pctx.errors.add(errorf(tnode, "unsupported type %q for type enum variant", named)) - } - } - - return optional.None[*schema.TypeValue]() -} - -func visitValueSpec(pctx *parseContext, node *ast.ValueSpec) { - var enum *schema.Enum - i, ok := node.Type.(*ast.Ident) - if !ok { - return - } - enumOption, enumInterface := pctx.getEnumForTypeName(i.Name) - enum, ok = enumOption.Get() - if !ok { - return - } - if enumInterface.Ok() { - pctx.errors.add(errorf(node, "cannot attach enum value to %s because it a type enum", enum.Name)) - return - } - maybeErrorOnInvalidEnumMixing(pctx, node, enum.Name) - - c, ok := pctx.pkg.TypesInfo.Defs[node.Names[0]].(*types.Const) - if !ok { - pctx.errors.add(errorf(node, "could not extract enum %s: expected exactly one variant name", enum.Name)) - return - } - - if value, ok := visitConst(pctx, c).Get(); ok { - variant := &schema.EnumVariant{ - Pos: goPosToSchemaPos(c.Pos()), - Comments: parseComments(node.Doc), - Name: strcase.ToUpperCamel(c.Id()), - Value: value, - } - enum.Variants = append(enum.Variants, variant) - } else { - pctx.errors.add(errorf(node, "unsupported type %q for enum variant %q", c.Type(), c.Name())) - } -} - -// maybeErrorOnInvalidEnumMixing ensures value enums are not set as variants of type enums. -// How this gets parsed: -// -// //ftl:enum -// type TypeEnum interface { typeEnum() } -// -// type BadValueEnum int -// -// // This line causes BadValueEnum to be parsed as a TypeEnum variant. At this point, we -// // cannot determine if BadValueEnum is intended to be a value enum, so we must treat it -// // as any other type enum variant. -// func (BadValueEnum) typeEnum() {} -// -// // This line will error because this is where we find out that BadValueEnum is intended -// // to be a value enum, but value enums cannot be variants of type enums. BadValueEnum -// // is not in pctx.enums. -// const A BadValueEnum = 1 -func maybeErrorOnInvalidEnumMixing(pctx *parseContext, node *ast.ValueSpec, enumName string) { - for _, decl := range pctx.module.Decls { - enum, ok := decl.(*schema.Enum) - if !ok { - continue - } - for _, variant := range enum.Variants { - if variant.Name == enumName { - pctx.errors.add(errorf(node, "cannot attach enum value to %s because it is a variant of type enum %s, not a value enum", enumName, enum.Name)) - } - } - } -} - func parseComments(doc *ast.CommentGroup) []string { comments := []string{} if doc := doc.Text(); doc != "" { @@ -1457,14 +1060,13 @@ func deref[T types.Object](pkg *packages.Package, node ast.Expr) (string, T) { } type parseContext struct { - pkg *packages.Package - pkgs []*packages.Package - module *schema.Module - nativeNames NativeNames - enumInterfaces enumInterfaces - errors errorSet - schema *schema.Schema - topicsByPos map[schema.Position]*schema.Topic + pkg *packages.Package + pkgs []*packages.Package + module *schema.Module + nativeNames NativeNames + errors errorSet + schema *schema.Schema + topicsByPos map[schema.Position]*schema.Topic } func newParseContext(pkg *packages.Package, pkgs []*packages.Package, sch *schema.Schema, out *extract.Result) *parseContext { @@ -1472,14 +1074,13 @@ func newParseContext(pkg *packages.Package, pkgs []*packages.Package, sch *schem out.NativeNames = NativeNames{} } return &parseContext{ - pkg: pkg, - pkgs: pkgs, - module: out.Module, - nativeNames: out.NativeNames, - enumInterfaces: enumInterfaces{}, - errors: errorSet{}, - schema: sch, - topicsByPos: map[schema.Position]*schema.Topic{}, + pkg: pkg, + pkgs: pkgs, + module: out.Module, + nativeNames: out.NativeNames, + errors: errorSet{}, + schema: sch, + topicsByPos: map[schema.Position]*schema.Topic{}, } } @@ -1516,27 +1117,6 @@ func (p *parseContext) isPathInPkg(path string) bool { return strings.HasPrefix(path, p.pkg.PkgPath+"/") } -// getEnumForTypeName returns the enum and interface for a given type name. -func (p *parseContext) getEnumForTypeName(name string) (optional.Option[*schema.Enum], optional.Option[*types.Interface]) { - aDecl, ok := p.getDeclForTypeName(name).Get() - if !ok { - return optional.None[*schema.Enum](), optional.None[*types.Interface]() - } - decl, ok := aDecl.(*schema.Enum) - if !ok { - return optional.None[*schema.Enum](), optional.None[*types.Interface]() - } - nativeName, ok := p.nativeNames[decl] - if !ok { - return optional.None[*schema.Enum](), optional.None[*types.Interface]() - } - enumInterface, isTypeEnum := p.enumInterfaces[strings.Split(nativeName, ".")[1]] - if isTypeEnum { - return optional.Some(decl), optional.Some(enumInterface) - } - return optional.Some(decl), optional.None[*types.Interface]() -} - func (p *parseContext) getDeclForTypeName(name string) optional.Option[schema.Decl] { for _, decl := range p.module.Decls { nativeName, ok := p.nativeNames[decl] @@ -1600,19 +1180,3 @@ func tokenFileContainsPos(f *token.File, pos token.Pos) bool { base := f.Base() return base <= p && p < base+f.Size() } - -func isIotaEnum(node ast.Node) bool { - switch t := node.(type) { - case *ast.ValueSpec: - if len(t.Values) != 1 { - return false - } - return isIotaEnum(t.Values[0]) - case *ast.Ident: - return t.Name == "iota" - case *ast.BinaryExpr: - return isIotaEnum(t.X) || isIotaEnum(t.Y) - default: - return false - } -} diff --git a/go-runtime/compile/schema_test.go b/go-runtime/compile/schema_test.go index cf2728c3bc..597f6eae0e 100644 --- a/go-runtime/compile/schema_test.go +++ b/go-runtime/compile/schema_test.go @@ -67,19 +67,19 @@ func TestExtractModuleSchema(t *testing.T) { } export enum Color: String { - Red = "Red" Blue = "Blue" Green = "Green" + Red = "Red" Yellow = "Yellow" } // Comments about ColorInt. enum ColorInt: Int { - // RedInt is a color. - RedInt = 0 BlueInt = 1 // GreenInt is also a color. GreenInt = 2 + // RedInt is a color. + RedInt = 0 YellowInt = 3 } @@ -96,15 +96,15 @@ func TestExtractModuleSchema(t *testing.T) { } enum SimpleIota: Int { - Zero = 0 One = 1 Two = 2 + Zero = 0 } enum TypeEnum { - Option String? - InlineStruct one.InlineStruct AliasedStruct one.UnderlyingStruct + InlineStruct one.InlineStruct + Option String? ValueEnum one.ColorInt } @@ -202,15 +202,15 @@ func TestExtractModuleSchemaTwo(t *testing.T) { +typemap go "github.com/TBD54566975/ftl/go-runtime/compile/testdata.lib.NonFTLType" export enum TwoEnum: String { - Red = "Red" Blue = "Blue" Green = "Green" + Red = "Red" } export enum TypeEnum { - Scalar String - List [String] Exported two.Exported + List [String] + Scalar String WithoutDirective two.WithoutDirective } @@ -321,17 +321,18 @@ func TestExtractModuleSchemaNamedTypes(t *testing.T) { // UserSource, testing that defining an enum after struct works export enum UserSource: String { - Magazine = "magazine" - Friend = "friend" Ad = "ad" + Friend = "friend" + Magazine = "magazine" } // UserState, testing that defining an enum before struct works export enum UserState: String { - Onboarded = "onboarded" - Registered = "registered" Active = "active" Inactive = "inactive" + // Out of order + Onboarded = "onboarded" + Registered = "registered" } export data User { @@ -363,8 +364,21 @@ func TestExtractModuleSchemaParent(t *testing.T) { expected := `module parent { export typealias ChildAlias String + export enum ChildTypeEnum { + List [String] + Scalar String + } + + export enum ChildValueEnum: Int { + A = 0 + B = 1 + C = 2 + } + export data ChildStruct { name parent.ChildAlias? + valueEnum parent.ChildValueEnum + typeEnum parent.ChildTypeEnum } data Resp { @@ -548,6 +562,7 @@ func TestErrorReporting(t *testing.T) { `22:14-44: duplicate database declaration at 21:14-44`, `25:2-10: unsupported type "error" for field "BadParam"`, `28:2-17: unsupported type "uint64" for field "AnotherBadParam"`, + `31:2-13: enum variant "SameVariant" conflicts with existing enum variant of "EnumVariantConflictParent" at "196:2"`, `31:3-3: unexpected directive "ftl:export" attached for verb, did you mean to use '//ftl:verb export' instead?`, `37:36-36: unsupported request type "ftl/failing.Request"`, `37:50-50: unsupported response type "ftl/failing.Response"`, @@ -576,19 +591,19 @@ func TestErrorReporting(t *testing.T) { `79:63-63: second result must not be ftl.Unit`, // `86:1-2: duplicate declaration of "WrongResponse" at 79:6`, TODO: fix this `90:3-3: unexpected directive "ftl:verb"`, - `104:2-24: cannot attach enum value to BadValueEnum because it is a variant of type enum TypeEnum, not a value enum`, - `111:2-41: cannot attach enum value to BadValueEnumOrderDoesntMatter because it is a variant of type enum TypeEnum, not a value enum`, + `99:6-18: "BadValueEnum" is a value enum and cannot be tagged as a variant of type enum "TypeEnum" directly`, + `108:6-35: "BadValueEnumOrderDoesntMatter" is a value enum and cannot be tagged as a variant of type enum "TypeEnum" directly`, `124:21-60: config and secret names must be valid identifiers`, `130:1-1: schema declaration contains conflicting directives`, `130:1-26: only one directive expected when directive "ftl:enum" is present, found multiple`, - `146:1-35: type can not be a variant of more than 1 type enums (TypeEnum1, TypeEnum2)`, - `152:27-27: enum discriminator "TypeEnum3" cannot contain exported methods`, - `155:1-35: enum discriminator "NoMethodsTypeEnum" must define at least one method`, + `152:6-45: enum discriminator "TypeEnum3" cannot contain exported methods`, + `155:6-35: enum discriminator "NoMethodsTypeEnum" must define at least one method`, `167:3-14: unexpected token "d"`, `174:2-62: can not publish directly to topics in other modules`, `175:9-26: can not call verbs in other modules directly: use ftl.Call(…) instead`, `180:2-12: struct field unexported must be exported by starting with an uppercase letter`, `184:6-6: unsupported type "ftl/failing/child.BadChildStruct" for field "child"`, + `189:6-6: duplicate Data declaration for "failing.Redeclared" in "ftl/failing"; already declared in "ftl/failing/child"`, } assert.Equal(t, expected, actual) } diff --git a/go-runtime/compile/testdata/failing/child/child.go b/go-runtime/compile/testdata/failing/child/child.go index 581697e91d..4e35f68aa0 100644 --- a/go-runtime/compile/testdata/failing/child/child.go +++ b/go-runtime/compile/testdata/failing/child/child.go @@ -19,3 +19,14 @@ type WrongMappingExternal lib.NonFTLType //ftl:typemap go "github.com/TBD54566975/ftl/go-runtime/compile/testdata.lib.NonFTLType" //ftl:typemap go "github.com/TBD54566975/ftl/go-runtime/compile/testdata.lib.NonFTLType" type MultipleMappings lib.NonFTLType + +//ftl:data +type Redeclared struct { +} + +//ftl:enum +type EnumVariantConflictChild int + +const ( + SameVariant EnumVariantConflictChild = iota +) diff --git a/go-runtime/compile/testdata/failing/failing.go b/go-runtime/compile/testdata/failing/failing.go index b3f335208a..b802a369e6 100644 --- a/go-runtime/compile/testdata/failing/failing.go +++ b/go-runtime/compile/testdata/failing/failing.go @@ -184,3 +184,14 @@ type UnexportedFieldStruct struct { type BadChildField struct { Child child.BadChildStruct } + +//ftl:data +type Redeclared struct { +} + +//ftl:enum +type EnumVariantConflictParent int + +const ( + SameVariant EnumVariantConflictParent = iota +) diff --git a/go-runtime/compile/testdata/parent/child/child.go b/go-runtime/compile/testdata/parent/child/child.go index db88b54df4..f82e428ebc 100644 --- a/go-runtime/compile/testdata/parent/child/child.go +++ b/go-runtime/compile/testdata/parent/child/child.go @@ -7,7 +7,9 @@ import ( ) type ChildStruct struct { - Name ftl.Option[ChildAlias] + Name ftl.Option[ChildAlias] + ValueEnum ChildValueEnum + TypeEnum ChildTypeEnum } type ChildAlias string @@ -19,3 +21,23 @@ type Resp struct { func ChildVerb(ctx context.Context) (Resp, error) { return Resp{}, nil } + +type ChildValueEnum int + +const ( + A ChildValueEnum = iota + B + C +) + +type ChildTypeEnum interface { + tag() +} + +type Scalar string + +func (Scalar) tag() {} + +type List []string + +func (List) tag() {} diff --git a/go-runtime/ftl/testdata/go/typeregistry/go.mod b/go-runtime/ftl/testdata/go/typeregistry/go.mod index 3d0103d67b..69b5f956c0 100644 --- a/go-runtime/ftl/testdata/go/typeregistry/go.mod +++ b/go-runtime/ftl/testdata/go/typeregistry/go.mod @@ -59,4 +59,4 @@ require ( google.golang.org/protobuf v1.34.2 // indirect ) -replace github.com/TBD54566975/ftl => ../../../../.. +replace github.com/TBD54566975/ftl => ./../../../../.. diff --git a/go-runtime/ftl/testdata/go/typeregistry/subpackage/subpackage.go b/go-runtime/ftl/testdata/go/typeregistry/subpackage/subpackage.go index 7410cbc3a6..ebac79c5b1 100644 --- a/go-runtime/ftl/testdata/go/typeregistry/subpackage/subpackage.go +++ b/go-runtime/ftl/testdata/go/typeregistry/subpackage/subpackage.go @@ -1,12 +1,5 @@ package subpackage -import ( - "context" - "ftl/builtin" - - "github.com/TBD54566975/ftl/go-runtime/ftl" // Import the FTL SDK. -) - //ftl:enum type StringsTypeEnum interface { tag() @@ -25,18 +18,3 @@ type Object struct { } func (Object) tag() {} - -type EchoRequest struct { - Strings StringsTypeEnum -} - -type EchoResponse struct { - Strings StringsTypeEnum -} - -//ftl:ingress POST /echo -func Echo(ctx context.Context, req builtin.HttpRequest[EchoRequest]) (builtin.HttpResponse[EchoResponse, string], error) { - return builtin.HttpResponse[EchoResponse, string]{ - Body: ftl.Some(EchoResponse{Strings: req.Body.Strings}), - }, nil -} diff --git a/go-runtime/ftl/testdata/go/typeregistry/typeregistry.go b/go-runtime/ftl/testdata/go/typeregistry/typeregistry.go index c392e799ec..f0d6dfd5b0 100644 --- a/go-runtime/ftl/testdata/go/typeregistry/typeregistry.go +++ b/go-runtime/ftl/testdata/go/typeregistry/typeregistry.go @@ -3,35 +3,17 @@ package typeregistry import ( "context" "ftl/builtin" + "ftl/typeregistry/subpackage" "github.com/TBD54566975/ftl/go-runtime/ftl" // Import the FTL SDK. ) -//ftl:enum -type StringsTypeEnum interface { - tag() -} - -type Single string - -func (Single) tag() {} - -type List []string - -func (List) tag() {} - -type Object struct { - S string -} - -func (Object) tag() {} - type EchoRequest struct { - Strings StringsTypeEnum + Strings subpackage.StringsTypeEnum } type EchoResponse struct { - Strings StringsTypeEnum + Strings subpackage.StringsTypeEnum } //ftl:ingress POST /echo diff --git a/go-runtime/ftl/testdata/go/typeregistry/typeregistry_test.go b/go-runtime/ftl/testdata/go/typeregistry/typeregistry_test.go index 90286a5947..0d7878ff90 100644 --- a/go-runtime/ftl/testdata/go/typeregistry/typeregistry_test.go +++ b/go-runtime/ftl/testdata/go/typeregistry/typeregistry_test.go @@ -2,6 +2,7 @@ package typeregistry import ( "ftl/builtin" + "ftl/typeregistry/subpackage" "testing" "github.com/TBD54566975/ftl/go-runtime/encoding" @@ -13,19 +14,19 @@ import ( func TestIngress(t *testing.T) { testCases := []struct { Name string - Input StringsTypeEnum + Input subpackage.StringsTypeEnum }{ { Name: "List", - Input: List([]string{"asdf", "qwerty"}), + Input: subpackage.List([]string{"asdf", "qwerty"}), }, { Name: "Single", - Input: Single("asdf"), + Input: subpackage.Single("asdf"), }, { Name: "Object", - Input: Object{S: "asdf"}, + Input: subpackage.Object{S: "asdf"}, }, } @@ -47,28 +48,28 @@ func TestIngress(t *testing.T) { func TestEncoding(t *testing.T) { testCases := []struct { Name string - Input StringsTypeEnum + Input subpackage.StringsTypeEnum Encoded string }{ { Name: "List", - Input: List([]string{"asdf", "qwerty"}), + Input: subpackage.List([]string{"asdf", "qwerty"}), Encoded: `{"input":{"name":"List","value":["asdf","qwerty"]}}`, }, { Name: "Single", - Input: Single("asdf"), + Input: subpackage.Single("asdf"), Encoded: `{"input":{"name":"Single","value":"asdf"}}`, }, { Name: "Object", - Input: Object{S: "asdf"}, + Input: subpackage.Object{S: "asdf"}, Encoded: `{"input":{"name":"Object","value":{"s":"asdf"}}}`, }, } type jsonObj struct { - Input StringsTypeEnum + Input subpackage.StringsTypeEnum } for _, test := range testCases { diff --git a/go-runtime/ftl/testdata/go/typeregistry/types.ftl.go b/go-runtime/ftl/testdata/go/typeregistry/types.ftl.go deleted file mode 100644 index 4fb4847817..0000000000 --- a/go-runtime/ftl/testdata/go/typeregistry/types.ftl.go +++ /dev/null @@ -1,14 +0,0 @@ -// Code generated by FTL. DO NOT EDIT. -package typeregistry - -import "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" - -func init() { - reflection.Register( - reflection.SumType[StringsTypeEnum]( - *new(Single), - *new(List), - *new(Object), - ), - ) -} diff --git a/go-runtime/schema/common/common.go b/go-runtime/schema/common/common.go index 46b3ac1c06..377e839035 100644 --- a/go-runtime/schema/common/common.go +++ b/go-runtime/schema/common/common.go @@ -344,7 +344,7 @@ func extractRef(pass *analysis.Pass, pos token.Pos, named *types.Named) optional if isLocalRef(pass, ref) { // mark this local reference to ensure its underlying schema type is hydrated by the appropriate extractor and // included in the schema - markNeedsExtraction(pass, named.Obj()) + MarkNeedsExtraction(pass, named.Obj()) } return optional.Some[schema.Type](ref) @@ -395,7 +395,7 @@ func ExtractTypeForNode(pass *analysis.Pass, obj types.Object, node ast.Node, in if im.Name() != ident.Name { continue } - switch im.Path() /*"." + typ.Sel.Name */ { + switch im.Path() + "." + typ.Sel.Name { case "time.Time": return optional.Some[schema.Type](&schema.Time{}) case FtlUnitTypePath: @@ -404,7 +404,14 @@ func ExtractTypeForNode(pass *analysis.Pass, obj types.Object, node ast.Node, in if index == nil { return optional.None[schema.Type]() } - return ExtractType(pass, node.Pos(), index) + if underlying, ok := ExtractType(pass, node.Pos(), index).Get(); ok { + return optional.Some[schema.Type](&schema.Optional{ + Pos: GoPosToSchemaPos(pass.Fset, node.Pos()), + Type: underlying, + }) + } + return optional.None[schema.Type]() + default: // Data ref if strings.HasPrefix(im.Path(), pass.Pkg.Path()) { // subpackage, same module @@ -435,11 +442,11 @@ func ExtractTypeForNode(pass *analysis.Pass, obj types.Object, node ast.Node, in } default: - variantNode := GetTypeInfoForNode(node, pass.TypesInfo) - if _, ok := variantNode.(*types.Struct); ok { - variantNode = obj.Type() + tnode := GetTypeInfoForNode(node, pass.TypesInfo) + if _, ok := tnode.(*types.Struct); ok { + tnode = obj.Type() } - return ExtractType(pass, node.Pos(), variantNode) + return ExtractType(pass, node.Pos(), tnode) } return optional.None[schema.Type]() @@ -457,14 +464,6 @@ func IsSelfReference(pass *analysis.Pass, obj types.Object, t schema.Type) bool return ref.Module == moduleName && strcase.ToUpperCamel(obj.Name()) == ref.Name } -func isLocalRef(pass *analysis.Pass, ref *schema.Ref) bool { - moduleName, err := FtlModuleFromGoPackage(pass.Pkg.Path()) - if err != nil { - return false - } - return ref.Module == "" || ref.Module == moduleName -} - func GetNativeName(obj types.Object) string { fqName := obj.Pkg().Path() if parts := strings.Split(obj.Pkg().Path(), "/"); parts[len(parts)-1] != obj.Pkg().Name() { @@ -476,3 +475,20 @@ func GetNativeName(obj types.Object) string { func IsExternalType(obj types.Object) bool { return !strings.HasPrefix(obj.Pkg().Path(), "ftl/") } + +func GetDeclTypeName(d schema.Decl) string { + typeStr := reflect.TypeOf(d).String() + lastDotIndex := strings.LastIndex(typeStr, ".") + if lastDotIndex == -1 { + return typeStr + } + return typeStr[lastDotIndex+1:] +} + +func isLocalRef(pass *analysis.Pass, ref *schema.Ref) bool { + moduleName, err := FtlModuleFromGoPackage(pass.Pkg.Path()) + if err != nil { + return false + } + return ref.Module == "" || ref.Module == moduleName +} diff --git a/go-runtime/schema/common/fact.go b/go-runtime/schema/common/fact.go index 3e8cd91892..e7b94c746c 100644 --- a/go-runtime/schema/common/fact.go +++ b/go-runtime/schema/common/fact.go @@ -4,17 +4,16 @@ import ( "go/types" "reflect" - "github.com/alecthomas/types/optional" - "github.com/TBD54566975/ftl/backend/schema" "github.com/TBD54566975/golang-tools/go/analysis" + "github.com/alecthomas/types/optional" ) // SchemaFact is a fact that associates a schema node with a Go object. type SchemaFact interface { analysis.Fact - Set(v SchemaFactValue) - Get() SchemaFactValue + Add(v SchemaFactValue) + Get() []SchemaFactValue } // DefaultFact should be used as the base type for all schema facts. Each @@ -25,12 +24,17 @@ type SchemaFact interface { // // type Fact = common.DefaultFact[struct{}] type DefaultFact[T any] struct { - value SchemaFactValue + value []SchemaFactValue } -func (*DefaultFact[T]) AFact() {} -func (t *DefaultFact[T]) Set(v SchemaFactValue) { t.value = v } -func (t *DefaultFact[T]) Get() SchemaFactValue { return t.value } +func (*DefaultFact[T]) AFact() {} +func (t *DefaultFact[T]) Add(v SchemaFactValue) { + if t.value == nil { + t.value = []SchemaFactValue{} + } + t.value = append(t.value, v) +} +func (t *DefaultFact[T]) Get() []SchemaFactValue { return t.value } // SchemaFactValue is the value of a SchemaFact. type SchemaFactValue interface { @@ -44,6 +48,32 @@ type ExtractedDecl struct { func (*ExtractedDecl) schemaFactValue() {} +// MaybeTypeEnum is a fact for marking an object as a possible type enum discriminator. +type MaybeTypeEnum struct { + Enum *schema.Enum +} + +func (*MaybeTypeEnum) schemaFactValue() {} + +// MaybeTypeEnumVariant is a fact for marking an object as a possible type enum variant. +type MaybeTypeEnumVariant struct { + GetValue func(pass *analysis.Pass) optional.Option[*schema.TypeValue] + // the parent enum + Parent types.Object + // this variant + Variant *schema.EnumVariant +} + +func (*MaybeTypeEnumVariant) schemaFactValue() {} + +// MaybeValueEnumVariant is a fact for marking an object as a possible value enum variant. +type MaybeValueEnumVariant struct { + // this variant + Variant *schema.EnumVariant +} + +func (*MaybeValueEnumVariant) schemaFactValue() {} + // ExtractedMetadata is a fact for associating an object with extracted schema metadata. type ExtractedMetadata struct { Type schema.Decl @@ -64,122 +94,156 @@ type FailedExtraction struct{} func (*FailedExtraction) schemaFactValue() {} -// MarkSchemaDecl marks the given object as having been extracted to the given schema node. +// MarkSchemaDecl marks the given object as having been extracted to the given schema decl. func MarkSchemaDecl(pass *analysis.Pass, obj types.Object, decl schema.Decl) { - fact := newFact(pass) - fact.Set(&ExtractedDecl{Decl: decl}) + fact := newFact(pass, obj) + fact.Add(&ExtractedDecl{Decl: decl}) pass.ExportObjectFact(obj, fact) } // MarkFailedExtraction marks the given object as having failed extraction. func MarkFailedExtraction(pass *analysis.Pass, obj types.Object) { - fact := newFact(pass) - fact.Set(&FailedExtraction{}) + fact := newFact(pass, obj) + fact.Add(&FailedExtraction{}) pass.ExportObjectFact(obj, fact) } func MarkMetadata(pass *analysis.Pass, obj types.Object, md *ExtractedMetadata) { - fact := newFact(pass) - fact.Set(md) + fact := newFact(pass, obj) + fact.Add(md) + pass.ExportObjectFact(obj, fact) +} + +// MarkNeedsExtraction marks the given object as needing extraction. +func MarkNeedsExtraction(pass *analysis.Pass, obj types.Object) { + fact := newFact(pass, obj) + fact.Add(&NeedsExtraction{}) pass.ExportObjectFact(obj, fact) } -// markNeedsExtraction marks the given object as needing extraction. -func markNeedsExtraction(pass *analysis.Pass, obj types.Object) { - fact := newFact(pass) - fact.Set(&NeedsExtraction{}) +// MarkMaybeTypeEnumVariant marks the given object as a possible type enum variant. +func MarkMaybeTypeEnumVariant(pass *analysis.Pass, obj types.Object, variant *schema.EnumVariant, + parent types.Object, valueFunc func(pass *analysis.Pass) optional.Option[*schema.TypeValue]) { + fact := newFact(pass, obj) + fact.Add(&MaybeTypeEnumVariant{Parent: parent, Variant: variant, GetValue: valueFunc}) pass.ExportObjectFact(obj, fact) } -// MergeAllFacts merges schema facts inclusive of all available results and the present pass facts. +// MarkMaybeValueEnumVariant marks the given object as a possible value enum variant. +func MarkMaybeValueEnumVariant(pass *analysis.Pass, obj types.Object, variant *schema.EnumVariant) { + fact := newFact(pass, obj) + fact.Add(&MaybeValueEnumVariant{Variant: variant}) + pass.ExportObjectFact(obj, fact) +} + +// MarkMaybeTypeEnum marks the given object as a possible type enum discriminator. +func MarkMaybeTypeEnum(pass *analysis.Pass, obj types.Object, enum *schema.Enum) { + fact := newFact(pass, obj) + fact.Add(&MaybeTypeEnum{Enum: enum}) + pass.ExportObjectFact(obj, fact) +} + +// GetAllFactsExtractionStatus merges schema facts inclusive of all available results and the present pass facts. +// For a given object, it provides the current extraction status. // -// If multiple facts are present for the same object, the facts will be prioritized by type: +// If multiple extraction facts are present for the same object, the facts will be prioritized by type: // 1. ExtractedDecl // 2. FailedExtraction -// 4. NeedsExtraction +// 3. NeedsExtraction // -// ExtractedMetadata facts are ignored. -func MergeAllFacts(pass *analysis.Pass) map[types.Object]SchemaFact { - facts := make(map[types.Object]SchemaFact) - for _, fact := range allFactsForPass(pass) { - f, ok := fact.Fact.(SchemaFact) +// All other fact types are ignored. +func GetAllFactsExtractionStatus(pass *analysis.Pass) map[types.Object]SchemaFactValue { + facts := make(map[types.Object]SchemaFactValue) + for _, fact := range allFacts(pass) { + sf, ok := fact.Fact.(SchemaFact) if !ok { continue } - // skip metadata facts - if _, ok = f.Get().(*ExtractedMetadata); ok { - continue - } - // prioritize facts by type // // e.g. if one extractor marked an object as needing extraction and another extractor marked it with the // completed extraction, we should prioritize the completed extraction. - prioritize := func(f SchemaFact) int { - switch f.Get().(type) { - case *ExtractedDecl: + prioritize := func(v SchemaFactValue) int { + switch v.(type) { + case *NeedsExtraction: return 1 case *FailedExtraction: return 2 - case *NeedsExtraction: + case *ExtractedDecl: return 3 default: - return 4 + return -1 } } - existing, ok := facts[fact.Object] - if !ok || prioritize(f) < prioritize(existing) { - facts[fact.Object] = f + for _, f := range sf.Get() { + newPriority := prioritize(f) + if newPriority == -1 { + continue + } + + existing, ok := facts[fact.Object] + existingPriority := prioritize(existing) + if !ok || newPriority > existingPriority { + facts[fact.Object] = f + } } } return facts } -func GetFact[T SchemaFactValue](facts []SchemaFact) optional.Option[T] { - for _, fact := range facts { - if f, ok := fact.Get().(T); ok { - return optional.Some(f) - } - } - return optional.None[T]() +// GetAllFacts returns all facts of the provided type marked on objects, across the current pass and results from +// prior passes. If multiple of the same fact type are marked on a single object, the first fact is returned. +func GetAllFacts[T SchemaFactValue](pass *analysis.Pass) map[types.Object]T { + return getFactsScoped[T](allFacts(pass)) } -// GetFactsForObject returns all facts marked on the object. -func GetFactsForObject[T SchemaFactValue](pass *analysis.Pass, obj types.Object) []T { - var facts []T - for _, fact := range allFactsForPass(pass) { - if fact.Object != obj { - continue - } +// GetCurrentPassFacts returns all facts of the provided type marked on objects during the current pass. +// If multiple of the same fact type are marked on a single object, the first fact is returned. +func GetCurrentPassFacts[T SchemaFactValue](pass *analysis.Pass) map[types.Object]T { + return getFactsScoped[T](pass.AllObjectFacts()) +} + +func getFactsScoped[T SchemaFactValue](scope []analysis.ObjectFact) map[types.Object]T { + facts := make(map[types.Object]T) + for _, fact := range scope { sf, ok := fact.Fact.(SchemaFact) if !ok { continue } - if f, ok := sf.Get().(T); ok { - facts = append(facts, f) + + for _, f := range sf.Get() { + if t, ok := f.(T); ok { + facts[fact.Object] = t + } } } return facts } -func GetFacts[T SchemaFactValue](pass *analysis.Pass) map[types.Object]T { - facts := make(map[types.Object]T) - for _, fact := range allFactsForPass(pass) { +// GetFactForObject returns the first fact of the provided type marked on the object. +func GetFactForObject[T SchemaFactValue](pass *analysis.Pass, obj types.Object) optional.Option[T] { + for _, fact := range allFacts(pass) { + if fact.Object != obj { + continue + } sf, ok := fact.Fact.(SchemaFact) if !ok { continue } - if f, ok := sf.Get().(T); ok { - facts[fact.Object] = f + for _, f := range sf.Get() { + if f, ok := f.(T); ok { + return optional.Some(f) + } } } - return facts + return optional.None[T]() } -// GetFactForObject returns the first fact of the provided type marked on the object. -func GetFactForObject[T SchemaFactValue](pass *analysis.Pass, obj types.Object) optional.Option[T] { - for _, fact := range allFactsForPass(pass) { +// GetFactsForObject returns the all facts of the provided type marked on the object. +func GetFactsForObject[T SchemaFactValue](pass *analysis.Pass, obj types.Object) []T { + facts := []T{} + for _, fact := range allFacts(pass) { if fact.Object != obj { continue } @@ -187,14 +251,16 @@ func GetFactForObject[T SchemaFactValue](pass *analysis.Pass, obj types.Object) if !ok { continue } - if f, ok := sf.Get().(T); ok { - return optional.Some(f) + for _, f := range sf.Get() { + if f, ok := f.(T); ok { + facts = append(facts, f) + } } } - return optional.None[T]() + return facts } -func allFactsForPass(pass *analysis.Pass) []analysis.ObjectFact { +func allFacts(pass *analysis.Pass) []analysis.ObjectFact { var all []analysis.ObjectFact all = append(all, pass.AllObjectFacts()...) for _, result := range pass.ResultOf { @@ -207,7 +273,21 @@ func allFactsForPass(pass *analysis.Pass) []analysis.ObjectFact { return all } -func newFact(pass *analysis.Pass) SchemaFact { - factType := reflect.TypeOf(pass.Analyzer.FactTypes[0]).Elem() - return reflect.New(factType).Interface().(SchemaFact) //nolint:forcetypeassert +func newFact(pass *analysis.Pass, obj types.Object) SchemaFact { + existing := optional.None[SchemaFact]() + for _, fact := range pass.AllObjectFacts() { + if fact.Object != obj { + continue + } + if sf, ok := fact.Fact.(SchemaFact); ok { + existing = optional.Some(sf) + } + } + + fact, ok := existing.Get() + if !ok { + factType := reflect.TypeOf(pass.Analyzer.FactTypes[0]).Elem() + fact = reflect.New(factType).Interface().(SchemaFact) //nolint:forcetypeassert + } + return fact } diff --git a/go-runtime/schema/enum/analyzer.go b/go-runtime/schema/enum/analyzer.go new file mode 100644 index 0000000000..2a5edb5cf7 --- /dev/null +++ b/go-runtime/schema/enum/analyzer.go @@ -0,0 +1,119 @@ +package enum + +import ( + "go/ast" + "go/types" + "slices" + "strings" + + "github.com/alecthomas/types/optional" + + "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/backend/schema/strcase" + "github.com/TBD54566975/ftl/go-runtime/schema/common" + "github.com/TBD54566975/golang-tools/go/analysis" +) + +// Extractor extracts type aliases to the module schema. +var Extractor = common.NewDeclExtractor[*schema.Enum, *ast.TypeSpec]("typealias", Extract) + +func Extract(pass *analysis.Pass, node *ast.TypeSpec, obj types.Object) optional.Option[*schema.Enum] { + valueVariants := findValueEnumVariants(pass, obj) + if facts := common.GetFactsForObject[*common.MaybeTypeEnumVariant](pass, obj); len(facts) > 0 && len(valueVariants) > 0 { + for _, te := range facts { + common.TokenErrorf(pass, obj.Pos(), obj.Name(), "%q is a value enum and cannot be tagged as a variant of type enum %q directly", + obj.Name(), te.Parent.Name()) + } + } + + // type enum + if discriminator, ok := common.GetFactForObject[*common.MaybeTypeEnum](pass, obj).Get(); ok { + if len(valueVariants) > 0 { + common.Errorf(pass, node, "type %q cannot be both a type and value enum", obj.Name()) + return optional.None[*schema.Enum]() + } + + e := discriminator.Enum + e.Variants = findTypeValueVariants(pass, obj) + slices.SortFunc(e.Variants, func(a, b *schema.EnumVariant) int { + return strings.Compare(a.Name, b.Name) + }) + return optional.Some(e) + } + + // value enum + if len(valueVariants) == 0 { + return optional.None[*schema.Enum]() + } + + typ, ok := common.ExtractType(pass, node.Pos(), pass.TypesInfo.TypeOf(node.Type)).Get() + if !ok { + return optional.None[*schema.Enum]() + } + + e := &schema.Enum{ + Pos: common.GoPosToSchemaPos(pass.Fset, node.Pos()), + Name: strcase.ToUpperCamel(obj.Name()), + Variants: valueVariants, + Type: typ, + } + if md, ok := common.GetFactForObject[*common.ExtractedMetadata](pass, obj).Get(); ok { + e.Comments = md.Comments + e.Export = md.IsExported + } + + return optional.Some(e) + +} + +func findValueEnumVariants(pass *analysis.Pass, obj types.Object) []*schema.EnumVariant { + var variants []*schema.EnumVariant + for o, fact := range common.GetAllFacts[*common.MaybeValueEnumVariant](pass) { + if o.Type() == obj.Type() && validateVariant(pass, o, fact.Variant) { + variants = append(variants, fact.Variant) + } + } + slices.SortFunc(variants, func(a, b *schema.EnumVariant) int { + return strings.Compare(a.Name, b.Name) + }) + return variants +} + +func validateVariant(pass *analysis.Pass, obj types.Object, variant *schema.EnumVariant) bool { + for _, fact := range common.GetAllFacts[*common.ExtractedDecl](pass) { + existingEnum, ok := fact.Decl.(*schema.Enum) + if !ok { + continue + } + for _, existingVariant := range existingEnum.Variants { + if existingVariant.Name == variant.Name && common.GoPosToSchemaPos(pass.Fset, obj.Pos()) != existingVariant.Pos { + common.TokenErrorf(pass, obj.Pos(), obj.Name(), "enum variant %q conflicts with existing enum "+ + "variant of %q at %q", variant.Name, existingEnum.GetName(), existingVariant.Pos) + return false + } + } + } + return true +} + +func findTypeValueVariants(pass *analysis.Pass, obj types.Object) []*schema.EnumVariant { + var variants []*schema.EnumVariant + for vObj, fact := range common.GetAllFacts[*common.MaybeTypeEnumVariant](pass) { + if fact.Parent != obj { + continue + } + // extract variant type here rather than in the `typeenumvariant` extractor so that we only + // call `common.ExtractType` if the enum/variant is actually part of the schema. + // + // the call to common.ExtractType sometimes results in transitive extraction, which we don't want during + // the initial pass marking all *possible* variants, as some may never be used. + value, ok := fact.GetValue(pass).Get() + if !ok { + common.NoEndColumnErrorf(pass, vObj.Pos(), "invalid type for enum variant %q", fact.Variant.Name) + } + fact.Variant.Value = value + variants = append(variants, fact.Variant) + + } + return variants +} diff --git a/go-runtime/schema/extract.go b/go-runtime/schema/extract.go index 1e47075328..5b4d493185 100644 --- a/go-runtime/schema/extract.go +++ b/go-runtime/schema/extract.go @@ -4,6 +4,10 @@ import ( "fmt" "go/types" + "github.com/TBD54566975/ftl/go-runtime/schema/enum" + "github.com/TBD54566975/ftl/go-runtime/schema/typeenum" + "github.com/TBD54566975/ftl/go-runtime/schema/typeenumvariant" + "github.com/TBD54566975/ftl/go-runtime/schema/valueenumvariant" "github.com/alecthomas/types/optional" "golang.org/x/exp/maps" @@ -35,10 +39,22 @@ var Extractors = [][]*analysis.Analyzer{ { metadata.Extractor, }, + { + // must run before typeenumvariant.Extractor; typeenum.Extractor determines all possible discriminator + // interfaces and typeenumvariant.Extractor determines any types that implement these + typeenum.Extractor, + }, { typealias.Extractor, verb.Extractor, data.Extractor, + valueenumvariant.Extractor, + typeenumvariant.Extractor, + }, + { + // must run after valueenumvariant.Extractor and typeenumvariant.Extractor; + // visits a node and aggregates its enum variants if present + enum.Extractor, }, { transitive.Extractor, @@ -118,6 +134,8 @@ func combineAllPackageResults(results map[*analysis.Analyzer][]any, diagnostics } refResults := make(map[schema.RefKey]refResult) extractedDecls := make(map[schema.Decl]types.Object) + // for identifying duplicates + declKeys := make(map[string]types.Object) for _, r := range fResults { fr, ok := r.(finalize.Result) if !ok { @@ -134,7 +152,18 @@ func combineAllPackageResults(results map[*analysis.Analyzer][]any, diagnostics } } copyFailedRefs(refResults, fr.Failed) - maps.Copy(extractedDecls, fr.Extracted) + for decl, obj := range fr.Extracted { + if existing, ok := declKeys[decl.String()]; ok && existing != obj { + // decls redeclared in subpackage + combined.Errors = append(combined.Errors, schema.Errorf(decl.Position(), decl.Position().Column, + "duplicate %s declaration for %q in %q; already declared in %q", common.GetDeclTypeName(decl), + combined.Module.Name+"."+decl.GetName(), obj.Pkg().Path(), existing.Pkg().Path())) + continue + } + declKeys[decl.String()] = obj + extractedDecls[decl] = obj + } + maps.Copy(combined.NativeNames, fr.NativeNames) } combined.Module.AddDecls(maps.Keys(extractedDecls)) diff --git a/go-runtime/schema/finalize/analyzer.go b/go-runtime/schema/finalize/analyzer.go index 53f7eb09e9..8eefaa3f5d 100644 --- a/go-runtime/schema/finalize/analyzer.go +++ b/go-runtime/schema/finalize/analyzer.go @@ -32,6 +32,8 @@ type Result struct { Extracted map[schema.Decl]types.Object // Failed contains all objects that failed extraction. Failed map[schema.RefKey]types.Object + // Native names that can't be derived outside of the analysis pass. + NativeNames map[schema.Node]string } func Run(pass *analysis.Pass) (interface{}, error) { @@ -41,21 +43,34 @@ func Run(pass *analysis.Pass) (interface{}, error) { } extracted := make(map[schema.Decl]types.Object) failed := make(map[schema.RefKey]types.Object) - for obj, fact := range common.MergeAllFacts(pass) { - switch f := fact.Get().(type) { + // for identifying duplicates + declKeys := make(map[string]types.Object) + for obj, fact := range common.GetAllFactsExtractionStatus(pass) { + switch f := fact.(type) { case *common.ExtractedDecl: - if f.Decl != nil { + if existing, ok := declKeys[f.Decl.String()]; ok && existing != obj && obj.Pkg().Path() == pass.Pkg.Path() { + common.NoEndColumnErrorf(pass, obj.Pos(), "duplicate %s declaration for %q; already declared at %q", + common.GetDeclTypeName(f.Decl), moduleName+"."+f.Decl.GetName(), common.GoPosToSchemaPos(pass.Fset, existing.Pos())) + continue + } + if f.Decl != nil && pass.Pkg.Path() == obj.Pkg().Path() { extracted[f.Decl] = obj + declKeys[f.Decl.String()] = obj } case *common.FailedExtraction: failed[schema.RefKey{Module: moduleName, Name: strcase.ToUpperCamel(obj.Name())}] = obj } } + nativeNames := make(map[schema.Node]string) + for obj, fact := range common.GetAllFacts[*common.MaybeTypeEnumVariant](pass) { + nativeNames[fact.Variant] = common.GetNativeName(obj) + } return Result{ ModuleName: moduleName, ModuleComments: extractModuleComments(pass), Extracted: extracted, Failed: failed, + NativeNames: nativeNames, }, nil } diff --git a/go-runtime/schema/metadata/analyzer.go b/go-runtime/schema/metadata/analyzer.go index ae16ce138d..dca95b3fe2 100644 --- a/go-runtime/schema/metadata/analyzer.go +++ b/go-runtime/schema/metadata/analyzer.go @@ -25,6 +25,8 @@ func Extract(pass *analysis.Pass) (interface{}, error) { in := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) //nolint:forcetypeassert nodeFilter := []ast.Node{ (*ast.GenDecl)(nil), + (*ast.TypeSpec)(nil), + (*ast.ValueSpec)(nil), (*ast.FuncDecl)(nil), } in.Preorder(nodeFilter, func(n ast.Node) { @@ -32,6 +34,8 @@ func Extract(pass *analysis.Pass) (interface{}, error) { switch n := n.(type) { case *ast.TypeSpec: doc = n.Doc + case *ast.ValueSpec: + doc = n.Doc case *ast.GenDecl: doc = n.Doc if ts, ok := n.Specs[0].(*ast.TypeSpec); len(n.Specs) > 0 && ok { @@ -202,8 +206,8 @@ func canRepeatDirective(dir common.Directive) bool { // TODO: fix - this doesn't work for member functions. // // func getDuplicate(pass *analysis.Pass, name string, newMd *common.ExtractedMetadata) optional.Option[types.Object] { -// for obj, md := range common.GetFacts[*common.ExtractedMetadata](pass) { -// if reflect.TypeOf(md.Type) == reflect.TypeOf(newMd.Type) && obj.Name() == name { +// for obj, md := range common.GetAllFacts[*common.ExtractedMetadata](pass) { +// if reflect.TypeOf(md.Type) == reflect.TypeOf(newMd.Type) && obj.Ref() == name { // return optional.Some(obj) // } // } diff --git a/go-runtime/schema/transitive/analyzer.go b/go-runtime/schema/transitive/analyzer.go index c7b44d0d0b..35c99e1f00 100644 --- a/go-runtime/schema/transitive/analyzer.go +++ b/go-runtime/schema/transitive/analyzer.go @@ -30,27 +30,29 @@ type Fact = common.DefaultFact[Tag] // annotated with an FTL directive. func Extract(pass *analysis.Pass) (interface{}, error) { needsExtraction := sets.NewSet[types.Object]() - for obj, fact := range common.MergeAllFacts(pass) { - if _, ok := fact.Get().(*common.NeedsExtraction); ok { + for obj, fact := range common.GetAllFactsExtractionStatus(pass) { + if _, ok := fact.(*common.NeedsExtraction); ok { needsExtraction.Add(obj) } } + + visited := sets.NewSet[types.Object]() for !needsExtraction.IsEmpty() { extractTransitive(pass, needsExtraction) - needsExtraction = refreshNeedsExtraction(pass) + visited.Append(needsExtraction.ToSlice()...) + needsExtraction = refreshNeedsExtraction(pass, visited) } return common.NewExtractorResult(pass), nil } -func refreshNeedsExtraction(pass *analysis.Pass) sets.Set[types.Object] { +func refreshNeedsExtraction(pass *analysis.Pass, visited sets.Set[types.Object]) sets.Set[types.Object] { facts := sets.NewSet[types.Object]() - for _, fact := range pass.AllObjectFacts() { - f, ok := fact.Fact.(common.SchemaFact) - if !ok { + for obj := range common.GetCurrentPassFacts[*common.NeedsExtraction](pass) { + if visited.Contains(obj) { continue } - if _, ok := f.Get().(*common.NeedsExtraction); ok && fact.Object.Pkg().Path() == pass.Pkg.Path() { - facts.Add(fact.Object) + if obj.Pkg().Path() == pass.Pkg.Path() { + facts.Add(obj) } } return facts @@ -109,6 +111,12 @@ func inferDeclType(pass *analysis.Pass, node ast.Node, obj types.Object) optiona return optional.None[schema.Decl]() } if !common.IsSelfReference(pass, obj, t) { + // if this is a type alias and it has enum variants, infer to be a value enum + for o := range common.GetAllFacts[*common.MaybeValueEnumVariant](pass) { + if o.Type() == obj.Type() { + return optional.Some[schema.Decl](&schema.Enum{}) + } + } return optional.Some[schema.Decl](&schema.TypeAlias{}) } return optional.Some[schema.Decl](&schema.Data{}) diff --git a/go-runtime/schema/typeenum/analyzer.go b/go-runtime/schema/typeenum/analyzer.go new file mode 100644 index 0000000000..1dc6864c41 --- /dev/null +++ b/go-runtime/schema/typeenum/analyzer.go @@ -0,0 +1,71 @@ +package typeenum + +import ( + "go/ast" + "go/types" + + "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/backend/schema/strcase" + "github.com/TBD54566975/ftl/go-runtime/schema/common" + "github.com/TBD54566975/golang-tools/go/analysis" + "github.com/TBD54566975/golang-tools/go/analysis/passes/inspect" + "github.com/TBD54566975/golang-tools/go/ast/inspector" +) + +// Extractor extracts possible type enum discriminators. +// +// All named interfaces are marked as possible type enum discriminators and subsequent extractors determine if they are +// part of an enum. +var Extractor = common.NewExtractor("typeenum", (*Fact)(nil), Extract) + +type Tag struct{} // Tag uniquely identifies the fact type for this extractor. +type Fact = common.DefaultFact[Tag] + +func Extract(pass *analysis.Pass) (interface{}, error) { + in := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) //nolint:forcetypeassert + nodeFilter := []ast.Node{ + (*ast.TypeSpec)(nil), + } + in.Preorder(nodeFilter, func(n ast.Node) { + node := n.(*ast.TypeSpec) //nolint:forcetypeassert + + iType, ok := pass.TypesInfo.TypeOf(node.Type).Underlying().(*types.Interface) + if !ok { + return + } + + obj, ok := common.GetObjectForNode(pass.TypesInfo, node).Get() + if !ok { + return + } + + enum := &schema.Enum{ + Pos: common.GoPosToSchemaPos(pass.Fset, node.Pos()), + Name: strcase.ToUpperCamel(node.Name.Name), + } + if md, ok := common.GetFactForObject[*common.ExtractedMetadata](pass, obj).Get(); ok { + enum.Comments = md.Comments + enum.Export = md.IsExported + + if _, ok := md.Type.(*schema.Enum); ok { + if iType.NumMethods() == 0 { + common.Errorf(pass, node, "enum discriminator %q must define at least one method", node.Name.Name) + return + } + for i := range iType.NumMethods() { + m := iType.Method(i) + if m.Exported() { + common.Errorf(pass, node, "enum discriminator %q cannot contain exported methods", + node.Name.Name) + return + } + } + common.MarkNeedsExtraction(pass, obj) + } + } + if iType.NumMethods() > 0 { + common.MarkMaybeTypeEnum(pass, obj, enum) + } + }) + return common.NewExtractorResult(pass), nil +} diff --git a/go-runtime/schema/typeenumvariant/analyzer.go b/go-runtime/schema/typeenumvariant/analyzer.go new file mode 100644 index 0000000000..f855416bc3 --- /dev/null +++ b/go-runtime/schema/typeenumvariant/analyzer.go @@ -0,0 +1,75 @@ +package typeenumvariant + +import ( + "go/ast" + "go/types" + + "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/backend/schema/strcase" + "github.com/TBD54566975/ftl/go-runtime/schema/common" + "github.com/TBD54566975/golang-tools/go/analysis" + "github.com/TBD54566975/golang-tools/go/analysis/passes/inspect" + "github.com/TBD54566975/golang-tools/go/ast/inspector" + "github.com/alecthomas/types/optional" +) + +// Extractor extracts possible type enum variants. +var Extractor = common.NewExtractor("typeenumvariant", (*Fact)(nil), Extract) + +type Tag struct{} // Tag uniquely identifies the fact type for this extractor. +type Fact = common.DefaultFact[Tag] + +func Extract(pass *analysis.Pass) (interface{}, error) { + in := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) //nolint:forcetypeassert + nodeFilter := []ast.Node{ + (*ast.TypeSpec)(nil), + } + in.Preorder(nodeFilter, func(n ast.Node) { + node := n.(*ast.TypeSpec) //nolint:forcetypeassert + obj, ok := common.GetObjectForNode(pass.TypesInfo, node).Get() + if !ok { + return + } + extractEnumVariant(pass, node, obj) + }) + return common.NewExtractorResult(pass), nil +} + +func extractEnumVariant(pass *analysis.Pass, node *ast.TypeSpec, obj types.Object) { + typ := pass.TypesInfo.TypeOf(node.Type) + if common.IsType[*types.Interface](typ) { + return + } + + variant := &schema.EnumVariant{ + Pos: common.GoPosToSchemaPos(pass.Fset, node.Pos()), + Name: strcase.ToUpperCamel(node.Name.Name), + } + if md, ok := common.GetFactForObject[*common.ExtractedMetadata](pass, obj).Get(); ok { + variant.Comments = md.Comments + } + for o := range common.GetAllFacts[*common.MaybeTypeEnum](pass) { + named, ok := pass.TypesInfo.TypeOf(node.Name).(*types.Named) + if !ok { + continue + } + iType := o.Type().Underlying().(*types.Interface) //nolint:forcetypeassert + if !types.Implements(named, iType) { + continue + } + + // valueFunc is only executed if this potential variant actually makes it to the schema. + // Executing may result in transitive schema extraction, so we only execute if necessary. + valueFunc := func(p *analysis.Pass) optional.Option[*schema.TypeValue] { + value, ok := common.ExtractTypeForNode(p, obj, node.Type, nil).Get() + if !ok { + return optional.None[*schema.TypeValue]() + } + return optional.Some(&schema.TypeValue{ + Pos: common.GoPosToSchemaPos(p.Fset, node.Pos()), + Value: value, + }) + } + common.MarkMaybeTypeEnumVariant(pass, obj, variant, o, valueFunc) + } +} diff --git a/go-runtime/schema/valueenumvariant/analyzer.go b/go-runtime/schema/valueenumvariant/analyzer.go new file mode 100644 index 0000000000..c7f067aaac --- /dev/null +++ b/go-runtime/schema/valueenumvariant/analyzer.go @@ -0,0 +1,131 @@ +package valueenumvariant + +import ( + "go/ast" + "go/token" + "go/types" + "strconv" + + "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/backend/schema/strcase" + "github.com/TBD54566975/ftl/go-runtime/schema/common" + "github.com/TBD54566975/golang-tools/go/analysis" + "github.com/TBD54566975/golang-tools/go/analysis/passes/inspect" + "github.com/TBD54566975/golang-tools/go/ast/inspector" + "github.com/alecthomas/types/optional" +) + +// Extractor extracts possible value enum variants. +// +// All named constants are marked as possible enum variants and subsequent extractors determine if they are part of an +// enum. +var Extractor = common.NewExtractor("valueenumvariant", (*Fact)(nil), Extract) + +type Tag struct{} // Tag uniquely identifies the fact type for this extractor. +type Fact = common.DefaultFact[Tag] + +func Extract(pass *analysis.Pass) (interface{}, error) { + in := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) //nolint:forcetypeassert + nodeFilter := []ast.Node{ + (*ast.GenDecl)(nil), + } + in.Preorder(nodeFilter, func(n ast.Node) { + node := n.(*ast.GenDecl) //nolint:forcetypeassert + if node.Tok != token.CONST { + return + } + + var typ ast.Expr + for i, s := range node.Specs { + v, ok := s.(*ast.ValueSpec) + if !ok { + continue + } + + // In an iota enum, only the first value has a type. + // Hydrate this to subsequent values so we can associate them with the enum. + if i == 0 && isIotaEnum(v) { + typ = v.Type + } else if v.Type == nil { + v.Type = typ + } + extractEnumVariant(pass, v) + } + }) + return common.NewExtractorResult(pass), nil +} + +func extractEnumVariant(pass *analysis.Pass, node *ast.ValueSpec) { + _, ok := node.Type.(*ast.Ident) + if !ok { + return + } + c, ok := pass.TypesInfo.Defs[node.Names[0]].(*types.Const) + if !ok { + return + } + value, ok := extractValue(pass, c).Get() + if !ok { + return + } + + obj, ok := common.GetObjectForNode(pass.TypesInfo, node).Get() + if !ok { + return + } + variant := &schema.EnumVariant{ + Pos: common.GoPosToSchemaPos(pass.Fset, c.Pos()), + Name: strcase.ToUpperCamel(c.Id()), + Value: value, + } + if md, ok := common.GetFactForObject[*common.ExtractedMetadata](pass, obj).Get(); ok { + variant.Comments = md.Comments + } + common.MarkMaybeValueEnumVariant(pass, obj, variant) +} + +func extractValue(pass *analysis.Pass, cnode *types.Const) optional.Option[schema.Value] { + if b, ok := cnode.Type().Underlying().(*types.Basic); ok { + switch b.Kind() { + case types.String: + value, err := strconv.Unquote(cnode.Val().String()) + if err != nil { + return optional.None[schema.Value]() + } + return optional.Some[schema.Value](&schema.StringValue{ + Pos: common.GoPosToSchemaPos(pass.Fset, cnode.Pos()), + Value: value, + }) + + case types.Int: + value, err := strconv.ParseInt(cnode.Val().String(), 10, 64) + if err != nil { + return optional.None[schema.Value]() + } + return optional.Some[schema.Value](&schema.IntValue{ + Pos: common.GoPosToSchemaPos(pass.Fset, cnode.Pos()), + Value: int(value), + }) + + default: + return optional.None[schema.Value]() + } + } + return optional.None[schema.Value]() +} + +func isIotaEnum(node ast.Node) bool { + switch t := node.(type) { + case *ast.ValueSpec: + if len(t.Values) != 1 { + return false + } + return isIotaEnum(t.Values[0]) + case *ast.Ident: + return t.Name == "iota" + case *ast.BinaryExpr: + return isIotaEnum(t.X) || isIotaEnum(t.Y) + default: + return false + } +}