diff --git a/src/Pose/Helpers/MethodHelper.cs b/src/Pose/Helpers/MethodHelper.cs new file mode 100644 index 0000000..98f22cd --- /dev/null +++ b/src/Pose/Helpers/MethodHelper.cs @@ -0,0 +1,13 @@ +using System; +using System.Reflection; + +namespace Pose.Helpers +{ + internal static class MethodHelper + { + public static MethodBase GetMethodFromHandle(RuntimeMethodHandle handle, RuntimeTypeHandle declaringType) + { + return MethodBase.GetMethodFromHandle(handle, declaringType); + } + } +} \ No newline at end of file diff --git a/src/Pose/Helpers/StubHelper.cs b/src/Pose/Helpers/StubHelper.cs index 31cf8cb..719202e 100644 --- a/src/Pose/Helpers/StubHelper.cs +++ b/src/Pose/Helpers/StubHelper.cs @@ -38,11 +38,28 @@ public static int GetIndexOfMatchingShim(MethodBase methodBase, Type type, objec s => ReferenceEquals(obj, s.Instance) && s.Original == methodBase); if (index == -1) - return Array.FindIndex(PoseContext.Shims, - s => SignatureEquals(s, type, methodBase) && s.Instance == null); + { + index = Array.FindIndex(PoseContext.Shims, s => SignatureEquals(s, type, methodBase) && s.Instance == null); + } + + if (index == -1) + { + index = Array.FindIndex(PoseContext.Shims, s => IsExplicitImplementation(s, type, methodBase) && s.Instance == null); + } return index; } + + private static bool IsExplicitImplementation(Shim shim, Type type, MethodBase method) + { + if (shim.Type == null || type == shim.Type) return false; + if (!shim.Type.IsInterface) return false; + if (!type.ImplementsInterface(shim.Type)) return false; + + var interfaceMap = type.GetInterfaceMap(shim.Type); + + return interfaceMap.TargetMethods.FirstOrDefault(m => m.Name == method.Name) != null; + } public static int GetIndexOfMatchingShim(MethodBase methodBase, object obj) => GetIndexOfMatchingShim(methodBase, methodBase.DeclaringType, obj); @@ -59,9 +76,64 @@ public static MethodInfo DeVirtualizeMethod(Type thisType, MethodInfo virtualMet var bindingFlags = BindingFlags.Instance | (virtualMethod.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic); var types = virtualMethod.GetParameters().Select(p => p.ParameterType).ToArray(); - return thisType.GetMethod(virtualMethod.Name, bindingFlags, null, types, null); + var method = thisType.GetMethod(virtualMethod.Name, bindingFlags, null, types, null); + + if (method == null) + { + // Attempt to get the method from an explicitly implemented interface + var interfaces = thisType.GetInterfaces(); + var declaringInterface = interfaces.FirstOrDefault(i => i == virtualMethod.DeclaringType); + var explicitlyImplementedMethod = GetExplicitlyImplementedMethod(declaringInterface, thisType, virtualMethod); + + return explicitlyImplementedMethod; + } + + return method; } + private static MethodInfo GetExplicitlyImplementedMethod(Type interfaceType, Type type, MethodInfo virtualMethod) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + if (virtualMethod == null) throw new ArgumentNullException(nameof(virtualMethod)); + if (!interfaceType.IsInterface) throw new InvalidOperationException($"{interfaceType} is not an interface."); + + var method = interfaceType.GetMethod(virtualMethod.Name) ?? throw new Exception(); + + var methodDeclaringType = method.DeclaringType ?? throw new Exception($"The {virtualMethod} method does not have a declaring type"); + + if (type.IsArray && methodDeclaringType.IsGenericType) + { + // Cannot retrieve interface maps for generic interfaces on arrays + return type + .GetMethods(BindingFlags.Instance | BindingFlags.NonPublic) + .SingleOrDefault(m => m.Name.EndsWith(virtualMethod.Name)); + } + + // Retrieve the method via the interface mapping for the type + var interfaceMapping = type.GetInterfaceMap(methodDeclaringType); + var interfaceMethods = interfaceMapping.InterfaceMethods.ToArray(); + var index = Array.FindIndex(interfaceMethods, m => m.DeclaringType == interfaceType && m.Name == virtualMethod.Name); + var targetMethod = interfaceMapping.TargetMethods[index]; + + return targetMethod; + } + + private static Type GetInterfaceType(this Type type, Type interfaceType) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + if (!interfaceType.IsInterface) throw new InvalidOperationException($"{interfaceType} is not an interface."); + + return type.GetInterfaces().FirstOrDefault(interfaceType1 => interfaceType1 == interfaceType); + } + + private static bool ImplementsInterface(this Type type, Type interfaceType) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + if (!interfaceType.IsInterface) throw new InvalidOperationException($"{interfaceType} is not an interface."); + + return type.GetInterfaceType(interfaceType) != null; + } + public static Module GetOwningModule() => typeof(StubHelper).Module; public static bool IsIntrinsic(MethodBase method) diff --git a/src/Pose/IL/Stubs.cs b/src/Pose/IL/Stubs.cs index 093dc85..5629d05 100644 --- a/src/Pose/IL/Stubs.cs +++ b/src/Pose/IL/Stubs.cs @@ -31,8 +31,8 @@ internal static class Stubs ?? throw new Exception($"Cannot get method {nameof(MethodRewriter.CreateRewriter)} from type {nameof(MethodRewriter)}"); private static readonly MethodInfo GetMethodFromHandle = - typeof(MethodBase).GetMethod(nameof(MethodBase.GetMethodFromHandle), new Type[] { typeof(RuntimeMethodHandle), typeof(RuntimeTypeHandle) }) - ?? throw new Exception($"Cannot get method {nameof(MethodBase.GetMethodFromHandle)} from type {nameof(MethodBase)}"); + typeof(MethodHelper).GetMethod(nameof(MethodHelper.GetMethodFromHandle), new Type[] { typeof(RuntimeMethodHandle), typeof(RuntimeTypeHandle) }) + ?? throw new Exception($"Cannot get method {nameof(MethodHelper.GetMethodFromHandle)} from type {nameof(MethodBase)}"); private static readonly MethodInfo GetIndexOfMatchingShim = typeof(StubHelper).GetMethod(nameof(StubHelper.GetIndexOfMatchingShim), new []{typeof(MethodBase), typeof(object)}) diff --git a/src/Sandbox/Program.cs b/src/Sandbox/Program.cs index b6472f3..2623630 100644 --- a/src/Sandbox/Program.cs +++ b/src/Sandbox/Program.cs @@ -1,9 +1,19 @@ // See https://aka.ms/new-console-template for more information using System; +using System.Collections.Generic; +using System.Linq; namespace Pose.Sandbox { + public static class TClass + { + public static int Get(this List list) + { + return 0; + } + } + public class Program { static void Constrain(TT a) where TT : IA{ @@ -53,8 +63,39 @@ public class OverridenOperatorClass public static OverridenOperatorClass operator +(OverridenOperatorClass l, OverridenOperatorClass r) => default(OverridenOperatorClass); } + public static IQueryable GetInts() + { + return new List().AsQueryable(); + } + public static void Main(string[] args) { + var countShim = Shim + .Replace(() => Is.A().Count()) + .With((IEnumerable ts) => 0); + + var getIntsShim = Shim + .Replace(() => Program.GetInts()) + .With(() => new List { 1 }.AsQueryable()); + + var tt = Shim + .Replace(() => Is.A>().Get()) + .With((List list) => 20); + + PoseContext.Isolate( + () => + { + var xs = new int[] { 0, 1, 2 }; + Console.WriteLine("X: " + xs.Length); + Console.WriteLine("Y: " + xs.Count()); + + var iis = Program.GetInts(); + Console.WriteLine(iis.LongCount()); + + //Console.WriteLine("X: " + Program.GetInts().Count()); + }, getIntsShim, countShim); + return; + #if NET48 Console.WriteLine("4.8"); var dateTimeShim = Shim.Replace(() => DateTime.Now).With(() => new DateTime(2004, 1, 1)); diff --git a/test/Pose.Tests/ShimTests.cs b/test/Pose.Tests/ShimTests.cs index ba816fc..82c89c5 100644 --- a/test/Pose.Tests/ShimTests.cs +++ b/test/Pose.Tests/ShimTests.cs @@ -13,8 +13,245 @@ namespace Pose.Tests { + internal static class InstanceExtensions + { + public static string GetString(this ShimTests.ExtensionMethods.ReferenceTypes.Instance instance) => null; + + public static string GetString(this ShimTests.ExtensionMethods.ValueTypes.Instance instance) => null; + + public static string GetString(this ShimTests.ExtensionMethods.SealedTypes.Instance instance) => null; + } + public class ShimTests { + public class ExplicitInterfaceImplementations + { + public class ReferenceTypes + { + private interface IInterface + { + string Text2 { get;} + + string GetText(); + } + + private class Instance : IInterface + { + public string Text { get; set; } + + string IInterface.Text2 => "Hey"; + + string IInterface.GetText() => "Hey"; + } + + [Fact] + public void Can_shim_explicit_interface_property() + { + // Arrange + const string configuredValue = "Hello"; + var action = new Func((IInterface @this) => configuredValue); + var shim = Shim.Replace(() => Is.A().Text2).With(action); + var instance = new Instance(); + + // Act + string value1 = default; + PoseContext.Isolate( + () => + { + value1 = ((IInterface)instance).Text2; + }, shim); + + // Assert + value1.Should().BeEquivalentTo(configuredValue, because: "the shim is configured for any instance"); + + IInterface secondInstance = new Instance(); + secondInstance.Text2.Should().NotBeEquivalentTo(value1, because: "this instance is created outside the isolated code"); + } + + [Fact] + public void Can_shim_explicit_interface_method() + { + // Arrange + const string configuredValue = "Hello"; + var action = new Func((IInterface @this) => configuredValue); + var shim = Shim.Replace(() => Is.A().GetText()).With(action); + var instance = new Instance(); + + // Act + string value1 = default; + PoseContext.Isolate( + () => + { + value1 = ((IInterface)instance).GetText(); + }, shim); + + // Assert + value1.Should().BeEquivalentTo(configuredValue, because: "the shim is configured for any instance"); + + IInterface secondInstance = new Instance(); + secondInstance.GetText().Should().NotBeEquivalentTo(value1, because: "this instance is created outside the isolated code"); + } + } + + public class ValueTypes + { + private interface IInterface + { + string Text2 { get;} + + string GetText(); + } + + private struct InstanceValue : IInterface + { + public string Text { get; set; } + + string IInterface.Text2 => "Hey"; + + string IInterface.GetText() => "Hey"; + } + + [Fact] + public void Can_shim_explicit_interface_property() + { + // Arrange + const string configuredValue = "Hello"; + var action = new Func((IInterface @this) => configuredValue); + var shim = Shim.Replace(() => Is.A().Text2).With(action); + var instance = new InstanceValue(); + + // Act + string value1 = default; + PoseContext.Isolate( + () => + { + value1 = ((IInterface)instance).Text2; + }, shim); + + // Assert + value1.Should().BeEquivalentTo(configuredValue, because: "the shim is configured for any instance"); + + IInterface secondInstance = new InstanceValue(); + secondInstance.Text2.Should().NotBeEquivalentTo(value1, because: "this instance is created outside the isolated code"); + } + + [Fact] + public void Can_shim_explicit_interface_method() + { + // Arrange + const string configuredValue = "Hello"; + var action = new Func((IInterface @this) => configuredValue); + var shim = Shim.Replace(() => Is.A().GetText()).With(action); + var instance = new InstanceValue(); + + // Act + string value1 = default; + PoseContext.Isolate( + () => + { + value1 = ((IInterface)instance).GetText(); + }, shim); + + // Assert + value1.Should().BeEquivalentTo(configuredValue, because: "the shim is configured for any instance"); + + IInterface secondInstance = new InstanceValue(); + secondInstance.GetText().Should().NotBeEquivalentTo(value1, because: "this instance is created outside the isolated code"); + } + } + } + + public class ExtensionMethods + { + public class ReferenceTypes + { + internal class Instance { } + + [Fact] + public void Can_shim_extension_method_of_reference_type() + { + // Arrange + const string configuredValue = "String"; + var action = new Func((Instance @this) => configuredValue); + var shim = Shim.Replace(() => Is.A().GetString()).With(action); + + // Act + string dt = default; + PoseContext.Isolate( + () => + { + var instance = new Instance(); + dt = instance.GetString(); + }, + shim + ); + + // Assert + dt.Should().BeEquivalentTo(configuredValue, because: "that is what the shim is configured to return"); + } + } + + public class ValueTypes + { + internal struct Instance { } + + [Fact] + public void Can_shim_extension_method_of_value_type() + { + // Arrange + const string configuredValue = "String"; + var action = new Func((Instance @this) => configuredValue); + var shim = Shim + .Replace(() => Is.A().GetString()) + .With(action); + + // Act + string value = default; + PoseContext.Isolate( + () => + { + value = new Instance().GetString(); + }, + shim + ); + + // Assert + value.Should().BeEquivalentTo(configuredValue, because: "that is what the shim is configured to return"); + } + } + + public class SealedTypes + { + internal sealed class Instance { } + + [Fact] + public void Can_shim_method_of_sealed_class() + { + // Arrange + const string configuredValue = "String"; + var action = new Func((Instance @this) => configuredValue); + var shim = Shim.Replace(() => Is.A().GetString()).With(action); + + // Act + string dt = default; + PoseContext.Isolate( + () => + { + var instance = new Instance(); + dt = instance.GetString(); + }, + shim + ); + + // Assert + dt.Should().BeEquivalentTo(configuredValue, because: "that is what the shim is configured to return"); + + var sealedClass = new Instance(); + dt.Should().NotBeEquivalentTo(sealedClass.GetString(), because: "that is the original value"); + } + } + } + public class Methods { public class StaticTypes