diff --git a/backend/schema/normalise.go b/backend/schema/normalise.go index 8cecd85cfa..af8d3245c1 100644 --- a/backend/schema/normalise.go +++ b/backend/schema/normalise.go @@ -4,146 +4,35 @@ import "github.com/TBD54566975/ftl/internal/reflect" // Normalise clones and normalises (zeroes) positional information in schema Nodes. func Normalise[T Node](n T) T { - n = reflect.DeepCopy(n) - var zero Position - var ni Node = n - switch c := ni.(type) { - case *Any: - c.Any = false - c.Pos = zero + ni := reflect.DeepCopy(n) + _ = Visit(ni, func(n Node, next func() error) error { + switch n := n.(type) { + case *Bool: + n.Bool = false - case *TypeParameter: - c.Pos = zero + case *Float: + n.Float = false - case *Unit: - c.Unit = true - c.Pos = zero + case *Int: + n.Int = false - case *Schema: - c.Pos = zero - c.Modules = normaliseSlice(c.Modules) + case *String: + n.Str = false - case *Module: - c.Pos = zero - c.Decls = normaliseSlice(c.Decls) + case *Any: + n.Any = false - case *Array: - c.Pos = zero - c.Element = Normalise(c.Element) + case *Unit: + n.Unit = true - case *Bool: - c.Bool = false - c.Pos = zero + case *Time: + n.Time = false - case *Data: - c.Pos = zero - c.TypeParameters = normaliseSlice(c.TypeParameters) - c.Fields = normaliseSlice(c.Fields) - c.Metadata = normaliseSlice(c.Metadata) - - case *Database: - c.Pos = zero - - case *Ref: - c.TypeParameters = normaliseSlice(c.TypeParameters) - c.Pos = zero - - case *Enum: - c.Pos = zero - if c.Type != nil { - c.Type = Normalise(c.Type) + default: // Normally we don't default for sum types, but this is just for tests and will be immediately obvious. } - c.Variants = normaliseSlice(c.Variants) - - case *EnumVariant: - c.Pos = zero - c.Value = Normalise(c.Value) - - case *TypeValue: - c.Pos = zero - c.Value = Normalise(c.Value) - - case *Field: - c.Pos = zero - c.Type = Normalise(c.Type) - c.Metadata = normaliseSlice(c.Metadata) - - case *Float: - c.Float = false - c.Pos = zero - - case *Int: - c.Int = false - c.Pos = zero - - case *IntValue: - c.Pos = zero - - case *Time: - c.Time = false - c.Pos = zero - - case *Map: - c.Pos = zero - c.Key = Normalise(c.Key) - c.Value = Normalise(c.Value) - - case *String: - c.Str = false - c.Pos = zero - - case *StringValue: - c.Pos = zero - - case *Bytes: - c.Bytes = false - c.Pos = zero - - case *Verb: - c.Pos = zero - c.Request = Normalise(c.Request) - c.Response = Normalise(c.Response) - c.Metadata = normaliseSlice(c.Metadata) - - case *MetadataCalls: - c.Pos = zero - c.Calls = normaliseSlice(c.Calls) - - case *MetadataDatabases: - c.Pos = zero - c.Calls = normaliseSlice(c.Calls) - - case *MetadataIngress: - c.Pos = zero - c.Path = normaliseSlice(c.Path) - - case *MetadataAlias: - c.Pos = zero - - case *Optional: - c.Type = Normalise(c.Type) - - case *IngressPathLiteral: - c.Pos = zero - - case *IngressPathParameter: - c.Pos = zero - - case *MetadataCronJob: - c.Pos = zero - - case *Config: - c.Pos = zero - c.Type = Normalise(c.Type) - - case *Secret: - c.Pos = zero - c.Type = Normalise(c.Type) - - case Named, Symbol, Decl, Metadata, IngressPathComponent, Type, Value: // Can never occur in reality, but here to satisfy the sum-type check. - panic("??") - } - return ni.(T) //nolint:forcetypeassert + return next() + }) + return ni //nolint:forcetypeassert } func normaliseSlice[T Node](in []T) []T { diff --git a/backend/schema/schema_test.go b/backend/schema/schema_test.go index 729f19d30b..eb8aa40498 100644 --- a/backend/schema/schema_test.go +++ b/backend/schema/schema_test.go @@ -177,7 +177,7 @@ func TestParserRoundTrip(t *testing.T) { assert.NoError(t, err, "%s", testSchema.String()) actual, err = ValidateSchema(actual) assert.NoError(t, err) - assert.Equal(t, Normalise(testSchema), Normalise(actual)) + assert.Equal(t, Normalise(testSchema), Normalise(actual), assert.Exclude[Position]()) } func TestParsing(t *testing.T) { @@ -380,7 +380,7 @@ func TestParsing(t *testing.T) { assert.NotZero(t, test.expected, "test.expected is nil") assert.NotZero(t, test.expected.Modules, "test.expected.Modules is nil") test.expected.Modules = append([]*Module{Builtins()}, test.expected.Modules...) - assert.Equal(t, Normalise(test.expected), Normalise(actual), assert.OmitEmpty()) + assert.Equal(t, Normalise(test.expected), Normalise(actual), assert.OmitEmpty(), assert.Exclude[Position]()) } }) } @@ -431,19 +431,19 @@ func TestParseEnum(t *testing.T) { Blue = "Blue" Green = "Green" } - + export enum ColorInt: Int { Red = 0 Blue = 1 Green = 2 } - + enum TypeEnum { A String B [String] C Int } - + enum StringTypeEnum { A String B String @@ -456,7 +456,7 @@ func TestParseEnum(t *testing.T) { actual, err := ParseModuleString("", input) assert.NoError(t, err) actual = Normalise(actual) - assert.Equal(t, Normalise(testSchema.Modules[2]), actual) + assert.Equal(t, Normalise(testSchema.Modules[2]), actual, assert.Exclude[Position]()) } var testSchema = MustValidate(&Schema{