From 48b89fc964c4a8cfdb45e7283a0e5c90c02012d7 Mon Sep 17 00:00:00 2001 From: Tim Corringham Date: Tue, 24 Sep 2024 14:49:42 +0100 Subject: [PATCH 1/3] Support SV_DispatchGrid semantic in a nested record The SV_DispatchGrid DXIL metadata for a node input record was not generated in cases where: - the field with the SV_DispatchGrid semantic was in a nested record - the field with the SV_DispatchGrid semantic was in a record field - the field with the SV_DispatchGrid semantic was inherited from a base record - in any combinations of the above Added FindDispatchGridSemantic() to be used by the AddHLSLNodeRecordTypeInfo() function, and added a test case. Fixes #6928 --- tools/clang/lib/CodeGen/CGHLSLMS.cpp | 130 ++++++++++-------- .../workgraph/nested_sv_dispatchgrid.hlsl | 78 +++++++++++ 2 files changed, 153 insertions(+), 55 deletions(-) create mode 100644 tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index 72f5a791ab..34e294aba9 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime { llvm::Value *DestPtr, clang::QualType DestTy) override; void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override; + bool FindDispatchGridSemantic(const CXXRecordDecl *RD, + hlsl::SVDispatchGrid &SDGRec, + CharUnits Offset = CharUnits()); void AddHLSLNodeRecordTypeInfo(const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node); void EmitHLSLFunctionProlog(llvm::Function *, @@ -2558,6 +2561,75 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { m_ScopeMap[F] = ScopeInfo(F, FD->getLocation()); } +// Find the input node record field with the SV_DispatchGrid semantic. +// We have already diagnosed any error conditions in Sema, so we +// expect valid size and types, and use the first occurrence found. 
+// We return true if we have populated the SV_DispatchGrid values. +bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD, + hlsl::SVDispatchGrid &SDGRec, + CharUnits Offset) { + const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD); + + // Collect any non-virtual bases. + SmallVector Bases; + for (const CXXBaseSpecifier &Base : RD->bases()) { + if (!Base.isVirtual() && !Base.getType()->isDependentType()) + Bases.push_back(Base.getType()->getAsCXXRecordDecl()); + } + + // Sort bases by offset. + std::stable_sort(Bases.begin(), Bases.end(), + [&](const CXXRecordDecl *L, const CXXRecordDecl *R) { + return Layout.getBaseClassOffset(L) < + Layout.getBaseClassOffset(R); + }); + + // Check (non-virtual) bases + for (const CXXRecordDecl *Base : Bases) { + CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(Base); + if (FindDispatchGridSemantic(Base, SDGRec, BaseOffset)) + return true; + } + + // Check each field in this record. + for (FieldDecl *Field : RD->fields()) { + uint64_t FieldNo = Field->getFieldIndex(); + CharUnits FieldOffset = Offset + CGM.getContext().toCharUnitsFromBits( + Layout.getFieldOffset(FieldNo)); + + // If this field is a record check its fields + if (const CXXRecordDecl *D = Field->getType()->getAsCXXRecordDecl()) { + if (FindDispatchGridSemantic(D, SDGRec, FieldOffset)) + return true; + } + // Otherwise check this field for the SV_DispatchGrid semantic annotation + for (const hlsl::UnusualAnnotation *it : Field->getUnusualAnnotations()) { + if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) { + const hlsl::SemanticDecl *sd = cast(it); + if (sd->SemanticName.equals("SV_DispatchGrid")) { + const llvm::Type *FTy = CGM.getTypes().ConvertType(Field->getType()); + const llvm::Type *ElTy = FTy; + SDGRec.NumComponents = 1; + SDGRec.ByteOffset = (unsigned)FieldOffset.getQuantity(); + if (const llvm::VectorType *VT = dyn_cast(FTy)) { + SDGRec.NumComponents = VT->getNumElements(); + ElTy = 
VT->getElementType(); + } else if (const llvm::ArrayType *AT = + dyn_cast(FTy)) { + SDGRec.NumComponents = AT->getNumElements(); + ElTy = AT->getElementType(); + } + SDGRec.ComponentType = (ElTy->getIntegerBitWidth() == 16) + ? DXIL::ComponentType::U16 + : DXIL::ComponentType::U32; + return true; + } + } + } + } + return false; +} + void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo( const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) { clang::QualType paramTy = parmDecl->getType().getCanonicalType(); @@ -2575,7 +2647,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo( DiagnosticsEngine &Diags = CGM.getDiags(); auto &Rec = TemplateArgs.get(0); clang::QualType RecType = Rec.getAsType(); - llvm::Type *Type = CGM.getTypes().ConvertType(RecType); CXXRecordDecl *RD = RecType->getAsCXXRecordDecl(); // Get the TrackRWInputSharing flag from the record attribute @@ -2595,63 +2666,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo( // Ex: For DispatchNodeInputRecord, set size = // size(MY_RECORD), alignment = alignof(MY_RECORD) + llvm::Type *Type = CGM.getTypes().ConvertType(RecType); node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type); node.RecordType.alignment = CGM.getDataLayout().getABITypeAlignment(Type); - // Iterate over fields of the MY_RECORD(example) struct - for (auto fieldDecl : RD->fields()) { - // Check if any of the fields have a semantic annotation = - // SV_DispatchGrid - for (const hlsl::UnusualAnnotation *it : - fieldDecl->getUnusualAnnotations()) { - if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) { - const hlsl::SemanticDecl *sd = cast(it); - // if we find a field with SV_DispatchGrid, fill out the - // SV_DispatchGrid member with byteoffset of the field, - // NumComponents (3 for uint3 etc) and U32 vs U16 types, which are - // the only types allowed - if (sd->SemanticName.equals("SV_DispatchGrid")) { - clang::QualType FT = fieldDecl->getType(); - auto &DL = CGM.getDataLayout(); - auto &SDGRec = 
node.RecordType.SV_DispatchGrid; - - DXASSERT_NOMSG(SDGRec.NumComponents == 0); - - unsigned fieldIdx = fieldDecl->getFieldIndex(); - if (StructType *ST = dyn_cast(Type)) { - SDGRec.ByteOffset = - DL.getStructLayout(ST)->getElementOffset(fieldIdx); - } - const llvm::Type *lTy = CGM.getTypes().ConvertType(FT); - if (const llvm::VectorType *VT = - dyn_cast(lTy)) { - DXASSERT(VT->getElementType()->isIntegerTy(), "invalid type"); - SDGRec.NumComponents = VT->getNumElements(); - SDGRec.ComponentType = - (VT->getElementType()->getIntegerBitWidth() == 16) - ? DXIL::ComponentType::U16 - : DXIL::ComponentType::U32; - } else if (const llvm::ArrayType *AT = - dyn_cast(lTy)) { - DXASSERT(AT->getElementType()->isIntegerTy(), "invalid type"); - DXASSERT_NOMSG(AT->getNumElements() <= 3); - SDGRec.NumComponents = AT->getNumElements(); - SDGRec.ComponentType = - (AT->getElementType()->getIntegerBitWidth() == 16) - ? DXIL::ComponentType::U16 - : DXIL::ComponentType::U32; - } else { - // Scalar U16 or U32 - DXASSERT(lTy->isIntegerTy(), "invalid type"); - SDGRec.NumComponents = 1; - SDGRec.ComponentType = (lTy->getIntegerBitWidth() == 16) - ? 
DXIL::ComponentType::U16 - : DXIL::ComponentType::U32; - } - } - } - } - } + + FindDispatchGridSemantic(RD, node.RecordType.SV_DispatchGrid); } } } diff --git a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl new file mode 100644 index 0000000000..5e9cb47cf1 --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl @@ -0,0 +1,78 @@ +// RUN: %dxc -T lib_6_8 %s | FileCheck %s + +// Check that the SV_DispatchGrid DXIL metadata for a node input record is +// generated in cases where: +// node1 - the field with the SV_DispatchGrid semantic is in a nested record +// node2 - the field with the SV_DispatchGrid semantic is in a record field +// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record +// node4 - the field with the SV_DispatchGrid semantic is within a nested record inherited from a base record +// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record + +struct Record1 { + struct { + // SV_DispatchGrid is within a nested record + uint3 grid : SV_DispatchGrid; + }; +}; + +[Shader("node")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node1(DispatchNodeInputRecord input) {} +// CHECK: , i32 1, ![[SVDG_1:[0-9]+]] +// CHECK: [[SVDG_1]] = !{i32 0, i32 5, i32 3} + +struct Record2a { + uint u; + uint2 grid : SV_DispatchGrid; +}; + +struct Record2 { + uint a; + // SV_DispatchGrid is within a record field + Record2a b; +}; + +[Shader("node")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node2(DispatchNodeInputRecord input) {} +// CHECK: , i32 1, ![[SVDG_2:[0-9]+]] +// CHECK: [[SVDG_2]] = !{i32 8, i32 5, i32 2} + +struct Record3 : Record2a { + // SV_DispatchGrid is inherited + uint4 n; +}; + +[Shader("node")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node3(DispatchNodeInputRecord input) {} +// CHECK: , i32 1, 
![[SVDG_3:[0-9]+]] +// CHECK: [[SVDG_3]] = !{i32 4, i32 5, i32 2} + +struct Record4 : Record2 { + // SV_DispatchGrid is in a nested field in a base record + float f; +}; + +[Shader("node")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node4(DispatchNodeInputRecord input) {} +// CHECK: , i32 1, ![[SVDG_2]] + +struct Record5 { + uint4 x; + // SV_DispatchGrid is in a base record of a record field + Record3 r; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node5(DispatchNodeInputRecord input) {} +// CHECK: , i32 1, ![[SVDG_5:[0-9]+]] +// CHECK: [[SVDG_5]] = !{i32 20, i32 5, i32 2} From 4c74ce184c220b2bd69e72506cbd7f7db53620c0 Mon Sep 17 00:00:00 2001 From: Tim Corringham Date: Tue, 10 Dec 2024 17:31:34 +0000 Subject: [PATCH 2/3] Address review comments Remove the check for a virtual base class from the code in FindDispatchGridSemantic() as virtual classes can't appear in HLSL code. --- tools/clang/lib/CodeGen/CGHLSLMS.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index 34e294aba9..79c3a72a31 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -2570,10 +2570,10 @@ bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD, CharUnits Offset) { const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD); - // Collect any non-virtual bases. + // Collect any bases. 
SmallVector Bases; for (const CXXBaseSpecifier &Base : RD->bases()) { - if (!Base.isVirtual() && !Base.getType()->isDependentType()) + if (!Base.getType()->isDependentType()) Bases.push_back(Base.getType()->getAsCXXRecordDecl()); } @@ -2584,7 +2584,7 @@ bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD, Layout.getBaseClassOffset(R); }); - // Check (non-virtual) bases + // Check bases in order for (const CXXRecordDecl *Base : Bases) { CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(Base); if (FindDispatchGridSemantic(Base, SDGRec, BaseOffset)) From a2df420f3c594976e654d55c05c2ae74145fc5f7 Mon Sep 17 00:00:00 2001 From: Tim Corringham Date: Thu, 12 Dec 2024 16:52:33 +0000 Subject: [PATCH 3/3] Add template test cases for nested SV_DispatchGrid Added test cases to cover nested SV_DispatchGrid used in records using templates. --- .../workgraph/nested_sv_dispatchgrid.hlsl | 60 +++++++++++++++++-- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl index 5e9cb47cf1..1da45dae1d 100644 --- a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl +++ b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl @@ -4,9 +4,12 @@ // generated in cases where: // node1 - the field with the SV_DispatchGrid semantic is in a nested record // node2 - the field with the SV_DispatchGrid semantic is in a record field -// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record -// node4 - the field with the SV_DispatchGrid semantic is within a nested record inherited from a base record -// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record +// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record +// node4 - the field with the SV_DispatchGrid semantic is within a 
nested record inherited from a base record +// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record +// node6 - the field with the SV_DispatchGrid semantic is within a templated base record +// node7 - the field with the SV_DispatchGrid semantic is within a templated base record of a templated record +// node8 - the field with the SV_DispatchGrid semantic has templated type struct Record1 { struct { @@ -19,6 +22,7 @@ struct Record1 { [NodeMaxDispatchGrid(32,16,1)] [NumThreads(32,1,1)] void node1(DispatchNodeInputRecord input) {} +// CHECK: {!"node1" // CHECK: , i32 1, ![[SVDG_1:[0-9]+]] // CHECK: [[SVDG_1]] = !{i32 0, i32 5, i32 3} @@ -37,6 +41,7 @@ struct Record2 { [NodeMaxDispatchGrid(32,16,1)] [NumThreads(32,1,1)] void node2(DispatchNodeInputRecord input) {} +// CHECK: {!"node2" // CHECK: , i32 1, ![[SVDG_2:[0-9]+]] // CHECK: [[SVDG_2]] = !{i32 8, i32 5, i32 2} @@ -49,6 +54,7 @@ struct Record3 : Record2a { [NodeMaxDispatchGrid(32,16,1)] [NumThreads(32,1,1)] void node3(DispatchNodeInputRecord input) {} +// CHECK: {!"node3" // CHECK: , i32 1, ![[SVDG_3:[0-9]+]] // CHECK: [[SVDG_3]] = !{i32 4, i32 5, i32 2} @@ -56,11 +62,12 @@ struct Record4 : Record2 { // SV_DispatchGrid is in a nested field in a base record float f; }; - + [Shader("node")] [NodeMaxDispatchGrid(32,16,1)] [NumThreads(32,1,1)] void node4(DispatchNodeInputRecord input) {} +// CHECK: {!"node4" // CHECK: , i32 1, ![[SVDG_2]] struct Record5 { @@ -74,5 +81,50 @@ struct Record5 { [NodeMaxDispatchGrid(32,16,1)] [NumThreads(32,1,1)] void node5(DispatchNodeInputRecord input) {} +// CHECK: {!"node5" // CHECK: , i32 1, ![[SVDG_5:[0-9]+]] // CHECK: [[SVDG_5]] = !{i32 20, i32 5, i32 2} + +template +struct Base { + T DG : SV_DispatchGrid; +}; + +struct Derived1 : Base { + int4 x; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node6(DispatchNodeInputRecord input) {} +// CHECK: {!"node6" +// CHECK: , i32 1, 
![[SVDG_1]] + +template +struct Derived2 : Base { + T Y; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node7(DispatchNodeInputRecord > input) {} +// CHECK: {!"node7" +// CHECK: , i32 1, ![[SVDG_7:[0-9]+]] +// CHECK: [[SVDG_7]] = !{i32 0, i32 5, i32 2} + +template +struct Derived3 { + Derived2 V; +}; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NodeMaxDispatchGrid(32,16,1)] +[NumThreads(32,1,1)] +void node8(DispatchNodeInputRecord< Derived3 > input) {} +// CHECK: {!"node8" +// CHECK: , i32 1, ![[SVDG_1]]