Skip to content

Commit

Permalink
#54 Add support for devirtualizing explicit interface implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
Miista committed Jan 16, 2025
1 parent 6590716 commit 8c9b23f
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 5 deletions.
13 changes: 13 additions & 0 deletions src/Pose/Helpers/MethodHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System;
using System.Reflection;

namespace Pose.Helpers
{
internal static class MethodHelper
{
public static MethodBase GetMethodFromHandle(RuntimeMethodHandle handle, RuntimeTypeHandle declaringType)
{
return MethodBase.GetMethodFromHandle(handle, declaringType);
}
}
}
78 changes: 75 additions & 3 deletions src/Pose/Helpers/StubHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,28 @@ public static int GetIndexOfMatchingShim(MethodBase methodBase, Type type, objec
s => ReferenceEquals(obj, s.Instance) && s.Original == methodBase);

if (index == -1)
return Array.FindIndex(PoseContext.Shims,
s => SignatureEquals(s, type, methodBase) && s.Instance == null);
{
index = Array.FindIndex(PoseContext.Shims, s => SignatureEquals(s, type, methodBase) && s.Instance == null);
}

if (index == -1)
{
index = Array.FindIndex(PoseContext.Shims, s => IsExplicitImplementation(s, type, methodBase) && s.Instance == null);
}

return index;
}

private static bool IsExplicitImplementation(Shim shim, Type type, MethodBase method)
{
if (shim.Type == null || type == shim.Type) return false;
if (!shim.Type.IsInterface) return false;
if (!type.ImplementsInterface(shim.Type)) return false;

var interfaceMap = type.GetInterfaceMap(shim.Type);

return interfaceMap.TargetMethods.FirstOrDefault(m => m.Name == method.Name) != null;
}

public static int GetIndexOfMatchingShim(MethodBase methodBase, object obj)
=> GetIndexOfMatchingShim(methodBase, methodBase.DeclaringType, obj);
Expand All @@ -59,9 +76,64 @@ public static MethodInfo DeVirtualizeMethod(Type thisType, MethodInfo virtualMet
var bindingFlags = BindingFlags.Instance | (virtualMethod.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);
var types = virtualMethod.GetParameters().Select(p => p.ParameterType).ToArray();

return thisType.GetMethod(virtualMethod.Name, bindingFlags, null, types, null);
var method = thisType.GetMethod(virtualMethod.Name, bindingFlags, null, types, null);

if (method == null)
{
// Attempt to get the method from an explicitly implemented interface
var interfaces = thisType.GetInterfaces();
var declaringInterface = interfaces.FirstOrDefault(i => i == virtualMethod.DeclaringType);
var explicitlyImplementedMethod = GetExplicitlyImplementedMethod(declaringInterface, thisType, virtualMethod);

return explicitlyImplementedMethod;
}

return method;
}

private static MethodInfo GetExplicitlyImplementedMethod(Type interfaceType, Type type, MethodInfo virtualMethod)
{
if (type == null) throw new ArgumentNullException(nameof(type));
if (virtualMethod == null) throw new ArgumentNullException(nameof(virtualMethod));
if (!interfaceType.IsInterface) throw new InvalidOperationException($"{interfaceType} is not an interface.");

var method = interfaceType.GetMethod(virtualMethod.Name) ?? throw new Exception();

var methodDeclaringType = method.DeclaringType ?? throw new Exception($"The {virtualMethod} method does not have a declaring type");

if (type.IsArray && methodDeclaringType.IsGenericType)
{
// Cannot retrieve interface maps for generic interfaces on arrays
return type
.GetMethods(BindingFlags.Instance | BindingFlags.NonPublic)
.SingleOrDefault(m => m.Name.EndsWith(virtualMethod.Name));
}

// Retrieve the method via the interface mapping for the type
var interfaceMapping = type.GetInterfaceMap(methodDeclaringType);
var interfaceMethods = interfaceMapping.InterfaceMethods.ToArray();
var index = Array.FindIndex(interfaceMethods, m => m.DeclaringType == interfaceType && m.Name == virtualMethod.Name);
var targetMethod = interfaceMapping.TargetMethods[index];

return targetMethod;
}

private static Type GetInterfaceType(this Type type, Type interfaceType)
{
if (type == null) throw new ArgumentNullException(nameof(type));
if (!interfaceType.IsInterface) throw new InvalidOperationException($"{interfaceType} is not an interface.");

return type.GetInterfaces().FirstOrDefault(interfaceType1 => interfaceType1 == interfaceType);
}

private static bool ImplementsInterface(this Type type, Type interfaceType)
{
if (type == null) throw new ArgumentNullException(nameof(type));
if (!interfaceType.IsInterface) throw new InvalidOperationException($"{interfaceType} is not an interface.");

return type.GetInterfaceType(interfaceType) != null;
}

public static Module GetOwningModule() => typeof(StubHelper).Module;

public static bool IsIntrinsic(MethodBase method)
Expand Down
4 changes: 2 additions & 2 deletions src/Pose/IL/Stubs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ internal static class Stubs
?? throw new Exception($"Cannot get method {nameof(MethodRewriter.CreateRewriter)} from type {nameof(MethodRewriter)}");

private static readonly MethodInfo GetMethodFromHandle =
typeof(MethodBase).GetMethod(nameof(MethodBase.GetMethodFromHandle), new Type[] { typeof(RuntimeMethodHandle), typeof(RuntimeTypeHandle) })
?? throw new Exception($"Cannot get method {nameof(MethodBase.GetMethodFromHandle)} from type {nameof(MethodBase)}");
typeof(MethodHelper).GetMethod(nameof(MethodHelper.GetMethodFromHandle), new Type[] { typeof(RuntimeMethodHandle), typeof(RuntimeTypeHandle) })
?? throw new Exception($"Cannot get method {nameof(MethodHelper.GetMethodFromHandle)} from type {nameof(MethodBase)}");

private static readonly MethodInfo GetIndexOfMatchingShim =
typeof(StubHelper).GetMethod(nameof(StubHelper.GetIndexOfMatchingShim), new []{typeof(MethodBase), typeof(object)})
Expand Down
41 changes: 41 additions & 0 deletions src/Sandbox/Program.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
// See https://aka.ms/new-console-template for more information

using System;
using System.Collections.Generic;
using System.Linq;

namespace Pose.Sandbox
{
public static class TClass
{
public static int Get(this List<int> list)
{
return 0;
}
}

public class Program
{
static void Constrain<TT>(TT a) where TT : IA{
Expand Down Expand Up @@ -53,8 +63,39 @@ public class OverridenOperatorClass
public static OverridenOperatorClass operator +(OverridenOperatorClass l, OverridenOperatorClass r) => default(OverridenOperatorClass);
}

public static IQueryable<int> GetInts()
{
return new List<int>().AsQueryable();
}

public static void Main(string[] args)
{
var countShim = Shim
.Replace(() => Is.A<int[]>().Count())
.With((IEnumerable<int> ts) => 0);

var getIntsShim = Shim
.Replace(() => Program.GetInts())
.With(() => new List<int> { 1 }.AsQueryable());

var tt = Shim
.Replace(() => Is.A<List<int>>().Get())
.With((List<int> list) => 20);

PoseContext.Isolate(
() =>
{
var xs = new int[] { 0, 1, 2 };
Console.WriteLine("X: " + xs.Length);
Console.WriteLine("Y: " + xs.Count());

var iis = Program.GetInts();
Console.WriteLine(iis.LongCount());

//Console.WriteLine("X: " + Program.GetInts().Count());
}, getIntsShim, countShim);
return;

#if NET48
Console.WriteLine("4.8");
var dateTimeShim = Shim.Replace(() => DateTime.Now).With(() => new DateTime(2004, 1, 1));
Expand Down
Loading

0 comments on commit 8c9b23f

Please sign in to comment.