From f69ecde5ee92bf1b0e8ccf7ed60a07af9a00af50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Mon, 18 Dec 2023 15:48:04 +0100 Subject: [PATCH] Fix #423: Merge the TypeBounds of type members without needing subtyping. There was an infinite recursion between looking up a type member of a refinement and subtyping of that member against the same refinement. This came from computing the merged TypeBounds of the type member during subtyping, which used subtyping to get rid of useless bounds. We break the cycle by not using subtyping when merging the TypeBounds of a type member anymore. Instead, we manually dive into possibly-higher-kinded bounds (`TypeLambda`s themselves, but also `Nothing` and `AnyKind`), and otherwise construction a union or intersection type. Unwrapping higher-kinded bounds is necessary because constructing a uniont or intersection requires proper types. --- .../src/main/scala/tastyquery/Types.scala | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/tasty-query/shared/src/main/scala/tastyquery/Types.scala b/tasty-query/shared/src/main/scala/tastyquery/Types.scala index a9459a06..6ab6c446 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Types.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Types.scala @@ -198,7 +198,7 @@ object Types { ResolveMemberResult.TypeMember(rightSyms, rightBounds) ) => val syms = mergeSyms(leftSyms, rightSyms) - val bounds = leftBounds.intersect(rightBounds) + val bounds = mergeTypeMemberTypeBounds(leftBounds, rightBounds) ResolveMemberResult.TypeMember(syms, bounds) // Cases that cannot happen -- list them to preserve exhaustivity checking of every other case @@ -228,6 +228,54 @@ object Types { case _ => throw InvalidProgramStructureException(s"Cannot merge types $tp1 and $tp2") end mergeTermMemberTypes + + private def mergeTypeMemberTypeBounds(bounds1: TypeBounds, bounds2: TypeBounds)(using Context): TypeBounds = + // This implementation assumes that the program structure is valid + (bounds1, bounds2) match + case _ if bounds1 eq bounds2 => + bounds1 + case (bounds1: TypeAlias, _) => + bounds1 + case (_, bounds2: TypeAlias) => + bounds2 + + case (bounds1 @ AbstractTypeBounds(low1, high1), bounds2 @ AbstractTypeBounds(low2, high2)) => + val mergedLow = mergeTypeMemberLowBounds(low1, low2) + val mergedHigh = mergeTypeMemberHighBounds(high1, high2) + bounds1.derivedTypeBounds(mergedLow, mergedHigh) + end mergeTypeMemberTypeBounds + + private def mergeTypeMemberLowBounds(low1: Type, low2: Type)(using Context): Type = + (low1.dealias, low2.dealias) match + case (low1: TypeLambda, low2: TypeLambda) if low1.paramNames.sizeCompare(low2.paramNames) == 0 => + low1.derivedLambdaType( + low1.paramNames, + low1.paramTypeBounds, + mergeTypeMemberLowBounds(low1.resultType, low2.instantiate(low1.paramRefs)) + ) + case (_: NothingType, _) | (_, _: AnyKindType) => + low2 + case (_, _: NothingType) | (_: AnyKindType, _) => + low1 + case _ => + low1 | low2 + end mergeTypeMemberLowBounds + + private def mergeTypeMemberHighBounds(high1: Type, high2: Type)(using Context): Type = + (high1.dealias, high2.dealias) match + case (high1: TypeLambda, high2: TypeLambda) if high1.paramNames.sizeCompare(high2.paramNames) == 0 => + high1.derivedLambdaType( + high1.paramNames, + high1.paramTypeBounds, + mergeTypeMemberHighBounds(high1.resultType, high2.instantiate(high1.paramRefs)) + ) + case (_: AnyKindType, _) | (_, _: NothingType) => + high2 + case (_, _: AnyKindType) | (_: NothingType, _) => + high1 + case _ => + high1 & high2 + end mergeTypeMemberHighBounds end ResolveMemberResult /** A type parameter of a type constructor.