diff --git a/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs b/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs index 27482d6ef..e74c05929 100644 --- a/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs +++ b/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs @@ -2726,7 +2726,20 @@ bool IsWinRTType(ISymbol symbol, TypeMapper mapper) { // Interfaces which are allowed to be implemented on authored types but // aren't WinRT interfaces. - return !ImplementedInterfacesWithoutMapping.Contains(QualifiedName(namedType)); + bool isMapped = ImplementedInterfacesWithoutMapping.Contains(QualifiedName(namedType)); + + if (isMapped) + return false; + + // If the interface is publicly accessible, it is a WinRT interface. + bool isPublic = namedType.IsPubliclyAccessible(); + + if (isPublic) + return true; + + // If it's not a publicly accessible interface, it's a WinRT interface if it has the + // [WindowsRuntimeType] attribute. + return namedType.GetAttributes().Any(static attribute => string.CompareOrdinal(attribute.AttributeClass.Name, "WindowsRuntimeTypeAttribute") == 0); } return namedType.SpecialType != SpecialType.System_Object && diff --git a/src/Tests/AuthoringConsumptionTest/test.cpp b/src/Tests/AuthoringConsumptionTest/test.cpp index 1b416fb1a..ea879b7c7 100644 --- a/src/Tests/AuthoringConsumptionTest/test.cpp +++ b/src/Tests/AuthoringConsumptionTest/test.cpp @@ -683,6 +683,13 @@ TEST(AuthoringTest, MixedWinRTClassicCOM) winrt::com_ptr<::IUnknown> internalInterface2; EXPECT_EQ(unknown2->QueryInterface(internalInterface2Iid, internalInterface2.put_void()), S_OK); + // Verify we can grab the generated COM interface + IID internalInterface3Iid; + check_hresult(IIDFromString(L"{6234C2F7-9917-469F-BDB4-3E8C630598AF}", &internalInterface3Iid)); + winrt::com_ptr<::IUnknown> unknown3 = wrapper.as<::IUnknown>(); + winrt::com_ptr<::IUnknown> internalInterface3; + EXPECT_EQ(unknown3->QueryInterface(internalInterface3Iid, internalInterface3.put_void()), S_OK); + typedef int (__stdcall* GetNumber)(void*, int*); int number; @@ -694,6 +701,10 @@ TEST(AuthoringTest, MixedWinRTClassicCOM) // Validate the second call on IInternalInterface2 EXPECT_EQ(reinterpret_cast((*reinterpret_cast(internalInterface2.get()))[3])(internalInterface2.get(), &number), S_OK); EXPECT_EQ(number, 123); + + // Validate the third call on IInternalInterface3 + EXPECT_EQ(reinterpret_cast((*reinterpret_cast(internalInterface3.get()))[3])(internalInterface3.get(), &number), S_OK); + EXPECT_EQ(number, 1); } TEST(AuthoringTest, GetRuntimeClassName) diff --git a/src/Tests/AuthoringTest/Program.cs b/src/Tests/AuthoringTest/Program.cs index 9910308bb..18ae28a63 100644 --- a/src/Tests/AuthoringTest/Program.cs +++ b/src/Tests/AuthoringTest/Program.cs @@ -1837,7 +1837,8 @@ public void Dispose() } } - public sealed class TestMixedWinRTCOMWrapper : IGraphicsEffectSource, IPublicInterface, IInternalInterface1, SomeInternalType.IInternalInterface2 + [System.Runtime.InteropServices.Marshalling.GeneratedComClass] + public sealed partial class TestMixedWinRTCOMWrapper : IGraphicsEffectSource, IPublicInterface, IInternalInterface1, SomeInternalType.IInternalInterface2, IInternalInterface3 { public string HelloWorld() { @@ -1857,6 +1858,11 @@ unsafe int SomeInternalType.IInternalInterface2.GetNumber(int* value) return 0; } + + int IInternalInterface3.GetNumber() + { + return 1; + } } public interface IPublicInterface @@ -1953,6 +1959,13 @@ private static int GetNumberFromAbi(void* thisPtr, int* value) } } + [System.Runtime.InteropServices.Guid("6234C2F7-9917-469F-BDB4-3E8C630598AF")] + [System.Runtime.InteropServices.Marshalling.GeneratedComInterface] + internal partial interface IInternalInterface3 + { + int GetNumber(); + } + [System.Runtime.InteropServices.Guid("26D8EE57-8B1B-46F4-A4F9-8C6DEEEAF53A")] public interface ICustomInterfaceGuid { diff --git a/src/WinRT.Runtime/ComWrappersSupport.cs b/src/WinRT.Runtime/ComWrappersSupport.cs index 7ea4e9029..5d72ff088 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.cs @@ -210,6 +210,24 @@ internal static List GetInterfaceTableEntries(Type type) } } +#if NET8_0_OR_GREATER + var comExposedDetails = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetComExposedTypeDetails(type.TypeHandle); + + if (comExposedDetails != null) + { + ReadOnlySpan comInterfaceEntries; + unsafe + { + ComInterfaceEntry* entriesPointer = comExposedDetails.GetComInterfaceEntries(out int generatedEntriesCount); + comInterfaceEntries = new ReadOnlySpan(entriesPointer, generatedEntriesCount); + } + foreach (var entry in comInterfaceEntries) + { + entries.Add(entry); + } + } +#endif + if (winrtExposedClassAttribute != null) { hasWinrtExposedClassAttribute = true;