diff --git a/go-runtime/compile/build.go b/go-runtime/compile/build.go index 4c3d0bef73..eb426057a3 100644 --- a/go-runtime/compile/build.go +++ b/go-runtime/compile/build.go @@ -59,13 +59,12 @@ type mainModuleContext struct { TypesCtx typesFileContext } -func (c mainModuleContext) withImports(mainModuleImport string) mainModuleContext { +func (c *mainModuleContext) withImports(mainModuleImport string) { c.MainCtx.Imports = c.generateMainImports() c.TypesCtx.Imports = c.generateTypesImports(mainModuleImport) - return c } -func (c mainModuleContext) generateMainImports() []string { +func (c *mainModuleContext) generateMainImports() []string { imports := sets.NewSet[string]() imports.Add(`"context"`) imports.Add(`"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect"`) @@ -90,7 +89,7 @@ func (c mainModuleContext) generateMainImports() []string { return imports.ToSlice() } -func (c mainModuleContext) generateTypesImports(mainModuleImport string) []string { +func (c *mainModuleContext) generateTypesImports(mainModuleImport string) []string { imports := sets.NewSet[string]() if len(c.TypesCtx.SumTypes) > 0 || len(c.TypesCtx.ExternalTypes) > 0 { imports.Add(`"github.com/TBD54566975/ftl/go-runtime/ftl/reflection"`) @@ -563,7 +562,7 @@ func buildMainModuleContext(sch *schema.Schema, result extract.Result, goModVers func (b *mainModuleContextBuilder) build(goModVersion, ftlVersion, projectName string, sharedModulesPaths []string, replacements []*modfile.Replace) (mainModuleContext, error) { - ctx := mainModuleContext{ + ctx := &mainModuleContext{ GoVersion: goModVersion, FTLVersion: ftlVersion, Name: b.mainModule.Name, @@ -582,8 +581,53 @@ func (b *mainModuleContextBuilder) build(goModVersion, ftlVersion, projectName s } visited := sets.NewSet[string]() - err := schema.Visit(b.mainModule, func(node schema.Node, next func() error) error { - maybeGoType, isLocal, err := b.getGoType(b.mainModule, node) + err := b.visit(ctx, b.mainModule, b.mainModule, visited) + if err != nil { + return mainModuleContext{}, err + } + + slices.SortFunc(ctx.MainCtx.SumTypes, func(a, b goSumType) int { + return strings.Compare(a.TypeName(), b.TypeName()) + }) + slices.SortFunc(ctx.TypesCtx.SumTypes, func(a, b goSumType) int { + return strings.Compare(a.TypeName(), b.TypeName()) + }) + + ctx.TypesCtx.MainModulePkg = b.mainModule.Name + mainModuleImport := fmt.Sprintf("ftl/%s", b.mainModule.Name) + if alias, ok := b.imports[mainModuleImport]; ok { + mainModuleImport = fmt.Sprintf("%s %q", alias, mainModuleImport) + ctx.TypesCtx.MainModulePkg = alias + } + ctx.withImports(mainModuleImport) + return *ctx, nil +} + +func (b *mainModuleContextBuilder) visit( + ctx *mainModuleContext, + module *schema.Module, + node schema.Node, + visited sets.Set[string], +) error { + err := schema.Visit(node, func(node schema.Node, next func() error) error { + if ref, ok := node.(*schema.Ref); ok { + maybeResolved, maybeModule := b.sch.ResolveWithModule(ref) + resolved, ok := maybeResolved.Get() + if !ok { + return next() + } + m, ok := maybeModule.Get() + if !ok { + return next() + } + err := b.visit(ctx, m, resolved, visited) + if err != nil { + return fmt.Errorf("failed to visit children of %s: %w", ref, err) + } + return next() + } + + maybeGoType, isLocal, err := b.getGoType(module, node) if err != nil { return err } @@ -611,40 +655,14 @@ func (b *mainModuleContextBuilder) build(goModVersion, ftlVersion, projectName s return next() }) if err != nil { - return mainModuleContext{}, fmt.Errorf("failed to build main module context: %w", err) - } - - slices.SortFunc(ctx.MainCtx.SumTypes, func(a, b goSumType) int { - return strings.Compare(a.TypeName(), b.TypeName()) - }) - slices.SortFunc(ctx.TypesCtx.SumTypes, func(a, b goSumType) int { - return strings.Compare(a.TypeName(), b.TypeName()) - }) - - ctx.TypesCtx.MainModulePkg = b.mainModule.Name - mainModuleImport := fmt.Sprintf("ftl/%s", b.mainModule.Name) - if alias, ok := b.imports[mainModuleImport]; ok { - mainModuleImport = fmt.Sprintf("%s %q", alias, mainModuleImport) - ctx.TypesCtx.MainModulePkg = alias + return fmt.Errorf("failed to build main module context: %w", err) } - return ctx.withImports(mainModuleImport), nil + return nil } func (b *mainModuleContextBuilder) getGoType(module *schema.Module, node schema.Node) (gotype optional.Option[goType], isLocal bool, err error) { isLocal = b.visitingMainModule(module.Name) switch n := node.(type) { - case *schema.Ref: - maybeResolved, maybeModule := b.sch.ResolveWithModule(n) - resolved, ok := maybeResolved.Get() - if !ok { - return optional.None[goType](), isLocal, nil - } - m, ok := maybeModule.Get() - if !ok { - return optional.None[goType](), isLocal, nil - } - return b.getGoType(m, resolved) - case *schema.Verb: if !isLocal { return optional.None[goType](), false, nil diff --git a/internal/buildengine/testdata/alpha/types.ftl.go b/internal/buildengine/testdata/alpha/types.ftl.go index 8a98b8c467..93ec7a961f 100644 --- a/internal/buildengine/testdata/alpha/types.ftl.go +++ b/internal/buildengine/testdata/alpha/types.ftl.go @@ -5,6 +5,7 @@ import ( "context" ftlother "ftl/other" "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" + lib "github.com/TBD54566975/ftl/go-runtime/schema/testdata" "github.com/TBD54566975/ftl/go-runtime/server" ) @@ -12,6 +13,7 @@ type EchoClient func(context.Context, EchoRequest) (EchoResponse, error) func init() { reflection.Register( + reflection.ExternalType(*new(lib.AnotherNonFTLType)), reflection.ProvideResourcesForVerb( Echo, server.VerbClient[ftlother.EchoClient, ftlother.EchoRequest, ftlother.EchoResponse](), diff --git a/internal/buildengine/testdata/type_registry_main.go b/internal/buildengine/testdata/type_registry_main.go index 2283fb995e..d14b8fd770 100644 --- a/internal/buildengine/testdata/type_registry_main.go +++ b/internal/buildengine/testdata/type_registry_main.go @@ -14,6 +14,10 @@ import ( func init() { reflection.Register( + reflection.SumType[ftlanother.SecondTypeEnum]( + *new(ftlanother.One), + *new(ftlanother.Two), + ), reflection.SumType[ftlanother.TypeEnum]( *new(ftlanother.A), *new(ftlanother.B),