diff --git a/migrations/statictypes/statictype_migration.go b/migrations/statictypes/statictype_migration.go index 51f2553cf3..cc3ac2a9bf 100644 --- a/migrations/statictypes/statictype_migration.go +++ b/migrations/statictypes/statictype_migration.go @@ -247,16 +247,31 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType, parentType inte legacyType := rewrittenIntersectionType.LegacyType + var mergedIntersections bool + var convertedLegacyType interpreter.StaticType if legacyType != nil { convertedLegacyType = m.maybeConvertStaticType(legacyType, rewrittenIntersectionType) - switch convertedLegacyType.(type) { + switch ty := convertedLegacyType.(type) { case nil, *interpreter.CompositeStaticType, interpreter.PrimitiveStaticType: // valid break + case *interpreter.IntersectionStaticType: + // If the legacy type was converted to an intersection type, + // then merge it into the resulting intersection type + + legacyType = nil + convertedLegacyType = nil + + convertedInterfaceTypes = append( + convertedInterfaceTypes, + ty.Types..., + ) + mergedIntersections = true + default: panic(fmt.Errorf( "invalid non-composite/primitive replacement for legacy type in intersection type %s:"+ @@ -273,13 +288,16 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType, parentType inte // If the interface set has at least two items, // then force it to be re-stored/re-encoded, // even if the interface types in the set have not changed. - if len(rewrittenIntersectionType.Types) >= 2 || convertedInterfaceType || convertedLegacyType != nil { + if len(rewrittenIntersectionType.Types) >= 2 || + convertedInterfaceType || + convertedLegacyType != nil || + mergedIntersections { result := interpreter.NewIntersectionStaticType(nil, convertedInterfaceTypes) if convertedLegacyType != nil { result.LegacyType = convertedLegacyType - } else { + } else if legacyType != nil { result.LegacyType = legacyType } diff --git a/migrations/statictypes/statictype_migration_test.go b/migrations/statictypes/statictype_migration_test.go index c9cecf7298..fd9b881550 100644 --- a/migrations/statictypes/statictype_migration_test.go +++ b/migrations/statictypes/statictype_migration_test.go @@ -604,6 +604,93 @@ func TestStaticTypeMigration(t *testing.T) { }) }) + + t.Run("merge converted legacy type when intersection", func(t *testing.T) { + + t.Parallel() + + const compositeQualifiedIdentifier = "S" + compositeType := interpreter.NewCompositeStaticType( + nil, + utils.TestLocation, + compositeQualifiedIdentifier, + utils.TestLocation.TypeID(nil, compositeQualifiedIdentifier), + ) + + const interface1QualifiedIdentifier = "SI1" + interfaceType1 := interpreter.NewInterfaceStaticType( + nil, + utils.TestLocation, + interface1QualifiedIdentifier, + utils.TestLocation.TypeID(nil, interface1QualifiedIdentifier), + ) + + const interface2QualifiedIdentifier = "SI2" + interfaceType2 := interpreter.NewInterfaceStaticType( + nil, + utils.TestLocation, + interface2QualifiedIdentifier, + utils.TestLocation.TypeID(nil, interface2QualifiedIdentifier), + ) + + intersectionType := interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + interfaceType1, + }, + ) + // NOTE: the legacy type is a composite type, + // but it will get rewritten to an intersection type + + intersectionType.LegacyType = compositeType + + staticTypeMigration := NewStaticTypeMigration().WithCompositeTypeConverter( + func(staticType *interpreter.CompositeStaticType) interpreter.StaticType { + if staticType.TypeID != compositeType.TypeID { + return nil + } + + return interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + interfaceType2, + }, + ) + }, + ) + + storedValue := interpreter.NewTypeValue( + nil, + interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + intersectionType, + ), + ) + + actual := migrate(t, + staticTypeMigration, + storedValue, + true, + ) + + expected := interpreter.NewTypeValue( + nil, + interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + interfaceType1, + interfaceType2, + }, + ), + ), + ) + + assert.Equal(t, expected, actual) + }) } func TestMigratingNestedContainers(t *testing.T) {