From aab4834002dfd299006b62333f3fd00123f1fb70 Mon Sep 17 00:00:00 2001 From: Rhodon Date: Thu, 4 May 2023 14:09:11 +0200 Subject: [PATCH 1/2] Fix interface property lookup in generic method --- .../Extensions/TypeExtensions.cs | 58 ++++++++----------- .../InheritedModelTests.cs | 42 ++++++++++++++ 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs index c9e1bef..ae171e7 100644 --- a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs +++ b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs @@ -41,23 +41,8 @@ private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo me return true; } - private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods) - { - if (allDerivedMethods is { Length: > 0 }) - { - var baseDefinition = methodInfo.GetBaseDefinition(); - for (var i = 0; i < allDerivedMethods.Length; i++) - { - var derivedMethodInfo = allDerivedMethods[i]; - if (derivedMethodInfo.GetBaseDefinition() == baseDefinition) - { - return i; - } - } - } - - return null; - } + private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition) + => methodInfo.GetBaseDefinition() == baseDefinition; public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) { @@ -68,31 +53,38 @@ public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo m var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i - ? derivedMethods[i] - // No derived methods were found. Return the original methodInfo - : methodInfo; + MethodInfo? overridingMethod = null; + if (derivedMethods is { Length: > 0 }) + { + var baseDefinition = methodInfo.GetBaseDefinition(); + overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo + => derivedMethodInfo.IsOverridingMethodOf(baseDefinition)); + } + + return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo } public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo) { - var accessor = propertyInfo.GetAccessors()[0]; - - if (!derivedType.CanHaveOverridingMethod(accessor)) + var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod); + if (accessor is null) { return propertyInfo; } + + var isGetAccessor = propertyInfo.GetMethod == accessor; var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - var derivedPropertyMethods = derivedProperties - .Select((Func) - (propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod)) - .OfType().ToArray(); - - return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i - ? derivedProperties[i] - // No derived methods were found. Return the original methodInfo - : propertyInfo; + + PropertyInfo? overridingProperty = null; + if (derivedProperties is { Length: > 0 }) + { + var baseDefinition = accessor.GetBaseDefinition(); + overridingProperty = derivedProperties.FirstOrDefault(p + => (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true); + } + + return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo } public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo) diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs index cea0103..d2a7f34 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs @@ -20,8 +20,20 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests [UsesVerify] public class InheritedModelTests { + public interface IBaseProvider + { + ICollection Bases { get; set; } + } + + public class BaseProvider : IBaseProvider + { + public int Id { get; set; } + public ICollection Bases { get; set; } + } + public interface IBase { + int Id { get; } int ComputedProperty { get; } int ComputedMethod(); } @@ -117,6 +129,26 @@ public Task ProjectOverImplementedMethod() return Verifier.Verify(query.ToQueryString()); } + + [Fact] + public Task ProjectOverProvider() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set().AllBases(); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task ProjectOverExtensionMethod() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set().Select(c => c.ComputedPropertyPlusMethod()); + + return Verifier.Verify(query.ToQueryString()); + } } public static class ModelExtensions @@ -128,5 +160,15 @@ public static IQueryable SelectComputedProperty(this IQueryable< public static IQueryable SelectComputedMethod(this IQueryable concretes) where TConcrete : InheritedModelTests.IBase => concretes.Select(x => x.ComputedMethod()); + + public static IQueryable AllBases(this IQueryable concretes) + where TProvider : InheritedModelTests.IBaseProvider + where TBase : InheritedModelTests.IBase + => concretes.SelectMany(x => x.Bases).Select(x => x.Id); + + [Projectable] + public static int ComputedPropertyPlusMethod(this TConcrete concrete) + where TConcrete : InheritedModelTests.IBase + => concrete.ComputedProperty + concrete.ComputedMethod(); } } From f806c1cd1f147976a0a2ff384f7c15113d1730ba Mon Sep 17 00:00:00 2001 From: Rhodon Date: Thu, 4 May 2023 14:24:04 +0200 Subject: [PATCH 2/2] Add test verification files --- .../Infrastructure/Internal/ProjectionOptionsExtension.cs | 4 ++-- ...nheritedModelTests.ProjectOverExtensionMethod.verified.txt | 2 ++ .../InheritedModelTests.ProjectOverProvider.verified.txt | 3 +++ 3 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverExtensionMethod.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverProvider.verified.txt diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs index b22e570..3cb2af7 100644 --- a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs @@ -54,7 +54,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryCompiler)); if (targetDescriptor is null) { - throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first"); ; + throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first"); } var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryCompiler), new[] { targetDescriptor.ServiceType }); @@ -70,7 +70,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryTranslationPreprocessorFactory)); if (targetDescriptor is null) { - throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first"); ; + throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first"); } var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryTranslationPreprocessorFactory), new[] { targetDescriptor.ServiceType }); diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverExtensionMethod.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverExtensionMethod.verified.txt new file mode 100644 index 0000000..d7fdac3 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverExtensionMethod.verified.txt @@ -0,0 +1,2 @@ +SELECT 4 +FROM [Concrete] AS [c] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverProvider.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverProvider.verified.txt new file mode 100644 index 0000000..a25407f --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverProvider.verified.txt @@ -0,0 +1,3 @@ +SELECT [c].[Id] +FROM [BaseProvider] AS [b] +INNER JOIN [Concrete] AS [c] ON [b].[Id] = [c].[BaseProviderId] \ No newline at end of file