From 09a9d673322ebf4ca2fcb7d48f13a44e015ea33f Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 9 Dec 2024 04:48:03 -0800 Subject: [PATCH] Allow pointers to existential values. (#5793) * Fix pointer offset logic and add executable tests. * Fix. * Fix test. * Add existential ptr test. * Allow pointers to existential values. * Fix. * Fix. --------- Co-authored-by: Ellie Hermaszewska --- source/slang/slang-check-decl.cpp | 11 +++++ source/slang/slang-check-expr.cpp | 47 ++++++++++++++----- source/slang/slang-check-impl.h | 19 ++++++-- source/slang/slang-check-type.cpp | 5 +- source/slang/slang-check.h | 1 + .../slang-ir-lower-buffer-element-type.cpp | 22 --------- tests/bugs/gh-3825.slang | 1 - tests/spirv/existential-ptr.slang | 37 +++++++++++++++ tests/spirv/ptr-member-func.slang | 29 ++++++++++++ tests/spirv/ptr-unsized-array-3.slang | 29 ++++++++++++ tests/spirv/ptr-unsized-array-4.slang | 25 ++++++++++ tools/render-test/render-test-main.cpp | 32 ++++++++++++- 12 files changed, 215 insertions(+), 43 deletions(-) create mode 100644 tests/spirv/existential-ptr.slang create mode 100644 tests/spirv/ptr-member-func.slang create mode 100644 tests/spirv/ptr-unsized-array-3.slang create mode 100644 tests/spirv/ptr-unsized-array-4.slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4a4ade047d..eeb75e3fd1 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3105,6 +3105,17 @@ Type* unwrapArrayType(Type* type) } } +Type* unwrapModifiedType(Type* type) +{ + for (;;) + { + if (auto modType = as(type)) + type = modType->getBase(); + else + return type; + } +} + void discoverExtensionDecls(List& decls, Decl* parent) { if (auto extDecl = as(parent)) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2840cdd394..1f2776ba01 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2307,7 +2307,10 @@ Expr* SemanticsVisitor::CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type* Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr) { bool needDeref = false; - auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression, needDeref); + auto baseExpr = checkBaseForMemberExpr( + subscriptExpr->baseExpression, + CheckBaseContext::Subscript, + needDeref); // If the base expression is a type, it means that this is an array declaration, // then we should disable short-circuit in case there is logical expression in @@ -2951,7 +2954,10 @@ Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) auto operatorName = getName("()"); bool needDeref = false; - expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref); + expr->functionExpr = maybeInsertImplicitOpForMemberBase( + expr->functionExpr, + CheckBaseContext::Member, + needDeref); LookupResult lookupResult = lookUpMember( m_astBuilder, @@ -4060,19 +4066,29 @@ void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr) } } -Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr) +Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext) { Expr* expr = inExpr; for (;;) { auto baseType = expr->type; + QualType elementType; if (auto pointerLikeType = as(baseType)) { - auto elementType = QualType(pointerLikeType->getElementType()); + elementType = QualType(pointerLikeType->getElementType()); elementType.isLeftValue = baseType.isLeftValue; elementType.hasReadOnlyOnTarget = baseType.hasReadOnlyOnTarget; elementType.isWriteOnly = baseType.isWriteOnly; - + } + else if (auto ptrType = as(baseType)) + { + if (checkBaseContext == CheckBaseContext::Subscript) + return expr; + elementType = QualType(ptrType->getValueType()); + elementType.isLeftValue = true; + } + if (elementType.type) + { auto derefExpr = m_astBuilder->create(); derefExpr->base = expr; derefExpr->type = elementType; @@ -4080,7 +4096,6 @@ Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr) expr = derefExpr; continue; } - // Default case: just use the expression as-is return expr; } @@ -4751,7 +4766,7 @@ Expr* SemanticsExprVisitor::visitStaticMemberExpr(StaticMemberExpr* expr) expr->baseExpression = CheckTerm(expr->baseExpression); // Not sure this is needed -> but guess someone could do - expr->baseExpression = MaybeDereference(expr->baseExpression); + expr->baseExpression = maybeDereference(expr->baseExpression, CheckBaseContext::Member); // If the base of the member lookup has an interface type // *without* a suitable this-type substitution, then we are @@ -4779,9 +4794,12 @@ Expr* SemanticsVisitor::lookupMemberResultFailure( return expr; } -Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref) +Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase( + Expr* baseExpr, + CheckBaseContext checkBaseContext, + bool& outNeedDeref) { - auto derefExpr = MaybeDereference(baseExpr); + auto derefExpr = maybeDereference(baseExpr, checkBaseContext); if (derefExpr != baseExpr) outNeedDeref = true; @@ -4834,11 +4852,15 @@ Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& return baseExpr; } -Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref) +Expr* SemanticsVisitor::checkBaseForMemberExpr( + Expr* inBaseExpr, + CheckBaseContext checkBaseContext, + bool& outNeedDeref) { auto baseExpr = inBaseExpr; baseExpr = CheckTerm(baseExpr); - return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref); + + return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref); } Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType) @@ -4861,7 +4883,8 @@ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* bas Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr* expr) { bool needDeref = false; - expr->baseExpression = checkBaseForMemberExpr(expr->baseExpression, needDeref); + expr->baseExpression = + checkBaseForMemberExpr(expr->baseExpression, CheckBaseContext::Member, needDeref); if (!needDeref && as(expr) && !as(expr->baseExpression->type)) { diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 95ec872a5a..460e87cb9b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2654,8 +2654,6 @@ struct SemanticsVisitor : public SemanticsContext // // - Expr* MaybeDereference(Expr* inExpr); - Expr* CheckMatrixSwizzleExpr( MemberExpr* memberRefExpr, Type* baseElementType, @@ -2696,11 +2694,24 @@ struct SemanticsVisitor : public SemanticsContext /// Perform checking operations required for the "base" expression of a member-reference like /// `base.someField` - Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref); + enum class CheckBaseContext + { + Member, + Subscript, + }; + Expr* checkBaseForMemberExpr( + Expr* baseExpr, + CheckBaseContext checkBaseContext, + bool& outNeedDeref); + + Expr* maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext); /// Prepare baseExpr for use as the base of a member expr. /// This include inserting implicit open-existential operations as needed. - Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref); + Expr* maybeInsertImplicitOpForMemberBase( + Expr* baseExpr, + CheckBaseContext checkBaseContext, + bool& outNeedDeref); Expr* lookupMemberResultFailure( DeclRefExpr* expr, diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index d9691a8281..2c8f3d0c08 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -216,9 +216,10 @@ bool isManagedType(Type* type) { if (auto declRefValueType = as(type)) { - if (as(declRefValueType->getDeclRef().getDecl())) + auto decl = declRefValueType->getDeclRef().getDecl(); + if (as(decl)) return true; - if (as(declRefValueType->getDeclRef().getDecl())) + if (as(decl) && decl->findModifier()) return true; } return false; diff --git a/source/slang/slang-check.h b/source/slang/slang-check.h index bd2bdce415..f1392e9cec 100644 --- a/source/slang/slang-check.h +++ b/source/slang/slang-check.h @@ -24,6 +24,7 @@ bool isFromCoreModule(Decl* decl); void registerBuiltinDecls(Session* session, Decl* decl); Type* unwrapArrayType(Type* type); +Type* unwrapModifiedType(Type* type); OrderedDictionary> getCanonicalGenericConstraints( ASTBuilder* builder, diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index bd3e350bca..dd62ca02c2 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -901,28 +901,6 @@ struct LoweredElementTypeContext { builder.setInsertBefore(ptrVal); auto newArrayPtrVal = fieldAddr->getBase(); - // Is base a pointer to an empty struct? If so, don't offset it. - // For example, if the user has written: - // ``` - // struct S {int arr[]}; - // uniform S* p; - // void test() { p->arr[1]; } - // ``` - // Then `S` will become an empty struct after we remove `arr[]`. - // And `p` will be come a `void*`. - // We don't want to offset `p` to `p+1` to get the starting address of - // the array in this case. - IRSizeAndAlignment parentStructSize = {}; - getNaturalSizeAndAlignment( - target->getOptionSet(), - tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), - &parentStructSize); - if (parentStructSize.size != 0) - { - newArrayPtrVal = builder.emitGetOffsetPtr( - fieldAddr->getBase(), - builder.getIntValue(builder.getIntType(), 1)); - } auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules); diff --git a/tests/bugs/gh-3825.slang b/tests/bugs/gh-3825.slang index c7c325864a..5953a858b4 100644 --- a/tests/bugs/gh-3825.slang +++ b/tests/bugs/gh-3825.slang @@ -21,7 +21,6 @@ float4 fragment(): SV_Target } // CHECK: OpDecorate %_ptr_PhysicalStorageBuffer_Descriptors_natural ArrayStride 4 -// CHECK: %{{.*}} = OpPtrAccessChain %_ptr_PhysicalStorageBuffer_Descriptors_natural %{{.*}} %int_1 // CHECK: OpBitcast %ulong // CHECK: OpIAdd %ulong %{{.*}} %ulong_4 // CHECK: OpBitcast %_ptr_PhysicalStorageBuffer \ No newline at end of file diff --git a/tests/spirv/existential-ptr.slang b/tests/spirv/existential-ptr.slang new file mode 100644 index 0000000000..66f1c64a29 --- /dev/null +++ b/tests/spirv/existential-ptr.slang @@ -0,0 +1,37 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly -output-using-type +//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu +//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12 +//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11 +//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal + +interface IFoo +{ + int getVal(); +} + +struct Foo : IFoo +{ + int val; + int getVal() { return val; } +} + +struct Bar : IFoo +{ + float val; + int getVal() { return (int)val + 1; } +} + +//TEST_INPUT: set pFoo = ubuffer(data=[0 0 2 0 2.0f], stride=4); +//TEST_INPUT: type_conformance Foo:IFoo = 1; +//TEST_INPUT: type_conformance Bar:IFoo = 2; +uniform IFoo* pFoo; + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4); +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + // CHECK: 3.0 + outputBuffer[0] = pFoo->getVal(); +} \ No newline at end of file diff --git a/tests/spirv/ptr-member-func.slang b/tests/spirv/ptr-member-func.slang new file mode 100644 index 0000000000..0dcf572ee9 --- /dev/null +++ b/tests/spirv/ptr-member-func.slang @@ -0,0 +1,29 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11 +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12 +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal + +struct Obj +{ + int val; + + [mutating] + void addOne() { val++; } + + int getValPlusOne() { return val + 1; } +} + +//TEST_INPUT: set pObj = ubuffer(data=[2 0 0 0], stride=4); +uniform Obj* pObj; + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4); +uniform RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + pObj->addOne(); + // CHECK: 4 + outputBuffer[0] = pObj->getValPlusOne(); +} \ No newline at end of file diff --git a/tests/spirv/ptr-unsized-array-3.slang b/tests/spirv/ptr-unsized-array-3.slang new file mode 100644 index 0000000000..ffd1345ea5 --- /dev/null +++ b/tests/spirv/ptr-unsized-array-3.slang @@ -0,0 +1,29 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11 +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12 +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal + +// Test a pointer to a struct with a trailing unsized array. + +struct MeshStorage { + int foo; + uint64_t QuadData[]; +}; + +//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4); +uniform MeshStorage* pStorage; + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4); +uniform RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + // CHECK: 5 + // CHECK: 6 + // CHECK: 1 + outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF); + outputBuffer[1] = (int)(pStorage.QuadData[1]>>32); + outputBuffer[2] = pStorage.foo; +} \ No newline at end of file diff --git a/tests/spirv/ptr-unsized-array-4.slang b/tests/spirv/ptr-unsized-array-4.slang new file mode 100644 index 0000000000..561dfab22f --- /dev/null +++ b/tests/spirv/ptr-unsized-array-4.slang @@ -0,0 +1,25 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11 +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12 +//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal + +// Test a pointer to a struct that has only one field and is an unsized array. +struct MeshStorage { + uint64_t QuadData[]; +}; + +//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4); +uniform MeshStorage* pStorage; + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4); +uniform RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + // CHECK: 3 + // CHECK: 4 + outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF); + outputBuffer[1] = (int)(pStorage.QuadData[1]>>32); +} \ No newline at end of file diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 2e07a76894..5907be66d5 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -76,6 +76,12 @@ struct ShaderOutputPlan List items; }; +// A context for hodling resources allocated for a test. +struct TestResourceContext +{ + List> resources; +}; + class RenderTestApp { public: @@ -134,6 +140,7 @@ class RenderTestApp Options m_options; ShaderOutputPlan m_outputPlan; + TestResourceContext m_resourceContext; }; struct AssignValsFromLayoutContext @@ -141,6 +148,7 @@ struct AssignValsFromLayoutContext IDevice* device; slang::ISession* slangSession; ShaderOutputPlan& outputPlan; + TestResourceContext& resourceContext; slang::ProgramLayout* slangReflection; IAccelerationStructure* accelerationStructure; @@ -148,11 +156,13 @@ struct AssignValsFromLayoutContext IDevice* device, slang::ISession* slangSession, ShaderOutputPlan& outputPlan, + TestResourceContext& resourceContext, slang::ProgramLayout* slangReflection, IAccelerationStructure* accelerationStructure) : device(device) , slangSession(slangSession) , outputPlan(outputPlan) + , resourceContext(resourceContext) , slangReflection(slangReflection) , accelerationStructure(accelerationStructure) { @@ -204,6 +214,7 @@ struct AssignValsFromLayoutContext bufferData.add(0); ComPtr bufferResource; + SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBuffer( srcBuffer, /*entry.isOutput,*/ bufferSize, @@ -211,6 +222,16 @@ struct AssignValsFromLayoutContext device, bufferResource)); + if (dstCursor.getTypeLayout()->getType()->getKind() == slang::TypeReflection::Kind::Pointer) + { + // dstCursor is pointer to an ordinary uniform data field, + // we should write bufferResource as a pointer. + uint64_t addr = bufferResource->getDeviceAddress(); + dstCursor.setData(&addr, sizeof(addr)); + resourceContext.resources.add(ComPtr(bufferResource.get())); + return SLANG_OK; + } + ComPtr counterResource; const auto explicitCounterCursor = dstCursor.getExplicitCounter(); if (srcBuffer.counter != ~0u) @@ -488,11 +509,17 @@ SlangResult _assignVarsFromLayout( IShaderObject* shaderObject, ShaderInputLayout const& layout, ShaderOutputPlan& ioOutputPlan, + TestResourceContext& ioResourceContext, slang::ProgramLayout* slangReflection, IAccelerationStructure* accelerationStructure) { - AssignValsFromLayoutContext - context(device, slangSession, ioOutputPlan, slangReflection, accelerationStructure); + AssignValsFromLayoutContext context( + device, + slangSession, + ioOutputPlan, + ioResourceContext, + slangReflection, + accelerationStructure); ShaderCursor rootCursor = ShaderCursor(shaderObject); return context.assign(rootCursor, layout.rootVal); } @@ -510,6 +537,7 @@ Result RenderTestApp::applyBinding(IShaderObject* rootObject) rootObject, m_compilationOutput.layout, m_outputPlan, + m_resourceContext, slangReflection, m_topLevelAccelerationStructure); }