Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SV_DispatchGrid semantic in a nested record #6931

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 75 additions & 55 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
llvm::Value *DestPtr,
clang::QualType DestTy) override;
void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override;
bool FindDispatchGridSemantic(const CXXRecordDecl *RD,
hlsl::SVDispatchGrid &SDGRec,
CharUnits Offset = CharUnits());
void AddHLSLNodeRecordTypeInfo(const clang::ParmVarDecl *parmDecl,
hlsl::NodeIOProperties &node);
void EmitHLSLFunctionProlog(llvm::Function *,
Expand Down Expand Up @@ -2558,6 +2561,75 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
m_ScopeMap[F] = ScopeInfo(F, FD->getLocation());
}

// Find the input node record field with the SV_DispatchGrid semantic.
// We have already diagnosed any error conditions in Sema, so we
// expect valid size and types, and use the first occurance found.
// We return true if we have populated the SV_DispatchGrid values.
bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD,
hlsl::SVDispatchGrid &SDGRec,
CharUnits Offset) {
const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD);

// Collect any bases.
SmallVector<const CXXRecordDecl *, 4> Bases;
for (const CXXBaseSpecifier &Base : RD->bases()) {
if (!Base.getType()->isDependentType())
Bases.push_back(Base.getType()->getAsCXXRecordDecl());
}

// Sort bases by offset.
std::stable_sort(Bases.begin(), Bases.end(),
[&](const CXXRecordDecl *L, const CXXRecordDecl *R) {
return Layout.getBaseClassOffset(L) <
Layout.getBaseClassOffset(R);
});

// Check bases in order
for (const CXXRecordDecl *Base : Bases) {
CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(Base);
if (FindDispatchGridSemantic(Base, SDGRec, BaseOffset))
return true;
}

// Check each field in this record.
for (FieldDecl *Field : RD->fields()) {
uint64_t FieldNo = Field->getFieldIndex();
CharUnits FieldOffset = Offset + CGM.getContext().toCharUnitsFromBits(
Layout.getFieldOffset(FieldNo));

// If this field is a record check its fields
if (const CXXRecordDecl *D = Field->getType()->getAsCXXRecordDecl()) {
if (FindDispatchGridSemantic(D, SDGRec, FieldOffset))
return true;
}
// Otherwise check this field for the SV_DispatchGrid semantic annotation
for (const hlsl::UnusualAnnotation *it : Field->getUnusualAnnotations()) {
if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
if (sd->SemanticName.equals("SV_DispatchGrid")) {
const llvm::Type *FTy = CGM.getTypes().ConvertType(Field->getType());
const llvm::Type *ElTy = FTy;
SDGRec.NumComponents = 1;
SDGRec.ByteOffset = (unsigned)FieldOffset.getQuantity();
if (const llvm::VectorType *VT = dyn_cast<llvm::VectorType>(FTy)) {
SDGRec.NumComponents = VT->getNumElements();
ElTy = VT->getElementType();
} else if (const llvm::ArrayType *AT =
dyn_cast<llvm::ArrayType>(FTy)) {
SDGRec.NumComponents = AT->getNumElements();
ElTy = AT->getElementType();
}
SDGRec.ComponentType = (ElTy->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
return true;
}
}
}
}
return false;
}

void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) {
clang::QualType paramTy = parmDecl->getType().getCanonicalType();
Expand All @@ -2575,7 +2647,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
DiagnosticsEngine &Diags = CGM.getDiags();
auto &Rec = TemplateArgs.get(0);
clang::QualType RecType = Rec.getAsType();
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
CXXRecordDecl *RD = RecType->getAsCXXRecordDecl();

// Get the TrackRWInputSharing flag from the record attribute
Expand All @@ -2595,63 +2666,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(

// Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
// size(MY_RECORD), alignment = alignof(MY_RECORD)
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type);
node.RecordType.alignment =
CGM.getDataLayout().getABITypeAlignment(Type);
// Iterate over fields of the MY_RECORD(example) struct
for (auto fieldDecl : RD->fields()) {
// Check if any of the fields have a semantic annotation =
// SV_DispatchGrid
for (const hlsl::UnusualAnnotation *it :
fieldDecl->getUnusualAnnotations()) {
if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
// if we find a field with SV_DispatchGrid, fill out the
// SV_DispatchGrid member with byteoffset of the field,
// NumComponents (3 for uint3 etc) and U32 vs U16 types, which are
// the only types allowed
if (sd->SemanticName.equals("SV_DispatchGrid")) {
clang::QualType FT = fieldDecl->getType();
auto &DL = CGM.getDataLayout();
auto &SDGRec = node.RecordType.SV_DispatchGrid;

DXASSERT_NOMSG(SDGRec.NumComponents == 0);

unsigned fieldIdx = fieldDecl->getFieldIndex();
if (StructType *ST = dyn_cast<StructType>(Type)) {
SDGRec.ByteOffset =
DL.getStructLayout(ST)->getElementOffset(fieldIdx);
}
const llvm::Type *lTy = CGM.getTypes().ConvertType(FT);
if (const llvm::VectorType *VT =
dyn_cast<llvm::VectorType>(lTy)) {
DXASSERT(VT->getElementType()->isIntegerTy(), "invalid type");
SDGRec.NumComponents = VT->getNumElements();
SDGRec.ComponentType =
(VT->getElementType()->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
} else if (const llvm::ArrayType *AT =
dyn_cast<llvm::ArrayType>(lTy)) {
DXASSERT(AT->getElementType()->isIntegerTy(), "invalid type");
DXASSERT_NOMSG(AT->getNumElements() <= 3);
SDGRec.NumComponents = AT->getNumElements();
SDGRec.ComponentType =
(AT->getElementType()->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
} else {
// Scalar U16 or U32
DXASSERT(lTy->isIntegerTy(), "invalid type");
SDGRec.NumComponents = 1;
SDGRec.ComponentType = (lTy->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
}
}
}
}
}

FindDispatchGridSemantic(RD, node.RecordType.SV_DispatchGrid);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// RUN: %dxc -T lib_6_8 %s | FileCheck %s

// Check that the SV_DispatchGrid DXIL metadata for a node input record is
// generated in cases where:
// node1 - the field with the SV_DispatchGrid semantic is in a nested record
// node2 - the field with the SV_DispatchGrid semantic is in a record field
// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record
// node4 - the field with the SV_DispatchGrid semantic is within a nested record inherited from a base record
// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record
// node6 - the field with the SV_DispatchGrid semantic is within a templated base record
// node7 - the field with the SV_DispatchGrid semantic is within a templated base record of a templated record
// node8 - the field with the SV_DispatchGrid semantic has templated type

struct Record1 {
struct {
// SV_DispatchGrid is within a nested record
uint3 grid : SV_DispatchGrid;
};
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node1(DispatchNodeInputRecord<Record1> input) {}
// CHECK: {!"node1"
// CHECK: , i32 1, ![[SVDG_1:[0-9]+]]
// CHECK: [[SVDG_1]] = !{i32 0, i32 5, i32 3}

struct Record2a {
uint u;
uint2 grid : SV_DispatchGrid;
};

struct Record2 {
uint a;
// SV_DispatchGrid is within a record field
Record2a b;
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node2(DispatchNodeInputRecord<Record2> input) {}
// CHECK: {!"node2"
// CHECK: , i32 1, ![[SVDG_2:[0-9]+]]
// CHECK: [[SVDG_2]] = !{i32 8, i32 5, i32 2}

struct Record3 : Record2a {
// SV_DispatchGrid is inherited
uint4 n;
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node3(DispatchNodeInputRecord<Record3> input) {}
// CHECK: {!"node3"
// CHECK: , i32 1, ![[SVDG_3:[0-9]+]]
// CHECK: [[SVDG_3]] = !{i32 4, i32 5, i32 2}

struct Record4 : Record2 {
// SV_DispatchGrid is in a nested field in a base record
float f;
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node4(DispatchNodeInputRecord<Record4> input) {}
// CHECK: {!"node4"
// CHECK: , i32 1, ![[SVDG_2]]

struct Record5 {
uint4 x;
// SV_DispatchGrid is in a base record of a record field
Record3 r;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node5(DispatchNodeInputRecord<Record5> input) {}
// CHECK: {!"node5"
// CHECK: , i32 1, ![[SVDG_5:[0-9]+]]
// CHECK: [[SVDG_5]] = !{i32 20, i32 5, i32 2}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about some test cases with templates?

Something like:

template <typename T>
struct Base {
  T DG : SV_DispatchGrid;
};

struct Derived1 : Base<uint3> {
  int4 x;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node6(DispatchNodeInputRecord<Derived1 > input) {}

template <typename T>
struct Derived2 : Base<T> {
  T Y;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node7(DispatchNodeInputRecord<Derived2<uint2> > input) {}


template <typename T>
struct Derived3 {
  Derived2<T> V;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node8(DispatchNodeInputRecord< Derived3 <uint3> > input) {}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea! I've updated the test to include these cases.


template <typename T>
struct Base {
T DG : SV_DispatchGrid;
};

struct Derived1 : Base<uint3> {
int4 x;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node6(DispatchNodeInputRecord<Derived1 > input) {}
// CHECK: {!"node6"
// CHECK: , i32 1, ![[SVDG_1]]

template <typename T>
struct Derived2 : Base<T> {
T Y;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node7(DispatchNodeInputRecord<Derived2<uint2> > input) {}
// CHECK: {!"node7"
// CHECK: , i32 1, ![[SVDG_7:[0-9]+]]
// CHECK: [[SVDG_7]] = !{i32 0, i32 5, i32 2}

template <typename T>
struct Derived3 {
Derived2<T> V;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node8(DispatchNodeInputRecord< Derived3 <uint3> > input) {}
// CHECK: {!"node8"
// CHECK: , i32 1, ![[SVDG_1]]
Loading