Skip to content

Commit

Permalink
Merge pull request #72 from rhodon-jargon/interface-support
Browse files Browse the repository at this point in the history
Fix interface property lookup in generic method
  • Loading branch information
koenbeuk authored May 4, 2023
2 parents d6b1cfb + f806c1c commit b458c71
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 35 deletions.
58 changes: 25 additions & 33 deletions src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,8 @@ private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo me
return true;
}

private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods)
{
if (allDerivedMethods is { Length: > 0 })
{
var baseDefinition = methodInfo.GetBaseDefinition();
for (var i = 0; i < allDerivedMethods.Length; i++)
{
var derivedMethodInfo = allDerivedMethods[i];
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
{
return i;
}
}
}

return null;
}
private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition)
=> methodInfo.GetBaseDefinition() == baseDefinition;

public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
{
Expand All @@ -81,31 +66,38 @@ public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo m

var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);

return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i
? derivedMethods[i]
// No derived methods were found. Return the original methodInfo
: methodInfo;
MethodInfo? overridingMethod = null;
if (derivedMethods is { Length: > 0 })
{
var baseDefinition = methodInfo.GetBaseDefinition();
overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo
=> derivedMethodInfo.IsOverridingMethodOf(baseDefinition));
}

return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo
}

public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
{
var accessor = propertyInfo.GetAccessors(true)[0];

if (!derivedType.CanHaveOverridingMethod(accessor))
var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod);
if (accessor is null)
{
return propertyInfo;
}

var isGetAccessor = propertyInfo.GetMethod == accessor;

var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
var derivedPropertyMethods = derivedProperties
.Select((Func<PropertyInfo, MethodInfo?>)
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
.OfType<MethodInfo>().ToArray();

return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i
? derivedProperties[i]
// No derived methods were found. Return the original methodInfo
: propertyInfo;

PropertyInfo? overridingProperty = null;
if (derivedProperties is { Length: > 0 })
{
var baseDefinition = accessor.GetBaseDefinition();
overridingProperty = derivedProperties.FirstOrDefault(p
=> (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true);
}

return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo
}

public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor
var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryCompiler));
if (targetDescriptor is null)
{
throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first"); ;
throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first");
}

var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryCompiler), new[] { targetDescriptor.ServiceType });
Expand All @@ -70,7 +70,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor
var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryTranslationPreprocessorFactory));
if (targetDescriptor is null)
{
throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first"); ;
throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first");
}

var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryTranslationPreprocessorFactory), new[] { targetDescriptor.ServiceType });
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT 4
FROM [Concrete] AS [c]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT [c].[Id]
FROM [BaseProvider] AS [b]
INNER JOIN [Concrete] AS [c] ON [b].[Id] = [c].[BaseProviderId]
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,20 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
[UsesVerify]
public class InheritedModelTests
{
public interface IBaseProvider<TBase>
{
ICollection<TBase> Bases { get; set; }
}

public class BaseProvider : IBaseProvider<Concrete>
{
public int Id { get; set; }
public ICollection<Concrete> Bases { get; set; }
}

public interface IBase
{
int Id { get; }
int ComputedProperty { get; }
int ComputedMethod();
}
Expand Down Expand Up @@ -117,6 +129,26 @@ public Task ProjectOverImplementedMethod()

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverProvider()
{
using var dbContext = new SampleDbContext<BaseProvider>();

var query = dbContext.Set<BaseProvider>().AllBases<BaseProvider, Concrete>();

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverExtensionMethod()
{
using var dbContext = new SampleDbContext<Concrete>();

var query = dbContext.Set<Concrete>().Select(c => c.ComputedPropertyPlusMethod());

return Verifier.Verify(query.ToQueryString());
}
}

public static class ModelExtensions
Expand All @@ -128,5 +160,15 @@ public static IQueryable<int> SelectComputedProperty<TConcrete>(this IQueryable<
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
where TConcrete : InheritedModelTests.IBase
=> concretes.Select(x => x.ComputedMethod());

public static IQueryable<int> AllBases<TProvider, TBase>(this IQueryable<TProvider> concretes)
where TProvider : InheritedModelTests.IBaseProvider<TBase>
where TBase : InheritedModelTests.IBase
=> concretes.SelectMany(x => x.Bases).Select(x => x.Id);

[Projectable]
public static int ComputedPropertyPlusMethod<TConcrete>(this TConcrete concrete)
where TConcrete : InheritedModelTests.IBase
=> concrete.ComputedProperty + concrete.ComputedMethod();
}
}

0 comments on commit b458c71

Please sign in to comment.