Skip to content

Commit

Permalink
Fix SPIRV struct reconstruction with bitfields (microsoft#5390)
Browse files Browse the repository at this point in the history
HLSL/SPIR-V structs have some layout differences due to bitfields being
squashed.
The reconstruction logic was using the AST layout, and not the new
SPIR-V layout, meaning we could generate invalid indices during
extraction/construction.

This PR fixes a potential bug with bitfields, while making the code
reusable.

---------

Signed-off-by: Nathan Gauër <[email protected]>
  • Loading branch information
Keenuts authored Jul 7, 2023
1 parent 8711ee6 commit dcf754c
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 56 deletions.
145 changes: 89 additions & 56 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,54 @@ const StructType *lowerStructType(const SpirvCodeGenOptions &spirvOptions,
return output;
}

// Calls `operation` on for each field in the base and derives class defined by
// `recordType`. The `operation` will receive the AST type linked to the field,
// the SPIRV type linked to the field, and the index of the field in the final
// SPIR-V representation. This index of the field can vary from the AST
// field-index because bitfields are merged into a single field in the SPIR-V
// representation.
//
// If the operation returns false, we stop processing fields.
void forEachSpirvField(
const RecordType *recordType, const StructType *spirvType,
std::function<bool(size_t spirvFieldIndex, const QualType &fieldType,
const StructType::FieldInfo &field)>
operation) {
const auto *cxxDecl = recordType->getAsCXXRecordDecl();
const auto *recordDecl = recordType->getDecl();

// Iterate through the base class (one field per base class).
// Bases cannot be melded into 1 field like bitfields, simple iteration.
uint32_t lastConvertedIndex = 0;
size_t astFieldIndex = 0;
for (const auto &base : cxxDecl->bases()) {
const auto &type = base.getType();
const auto &spirvField = spirvType->getFields()[astFieldIndex];
if (!operation(spirvField.fieldIndex, type, spirvField)) {
return;
}
lastConvertedIndex = spirvField.fieldIndex;
++astFieldIndex;
}

// Iterate through the derived class fields. Field could be merged.
for (const auto *field : recordDecl->fields()) {
const auto &spirvField = spirvType->getFields()[astFieldIndex];
const uint32_t currentFieldIndex = spirvField.fieldIndex;
if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
++astFieldIndex;
continue;
}

const auto &type = field->getType();
if (!operation(currentFieldIndex, type, spirvField)) {
return;
}
lastConvertedIndex = currentFieldIndex;
++astFieldIndex;
}
}

} // namespace

SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
Expand Down Expand Up @@ -6410,29 +6458,26 @@ SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal,

// Structs
if (const auto *recordType = valType->getAs<RecordType>()) {
uint32_t index = 0;
llvm::SmallVector<SpirvInstruction *, 4> elements;
assert(recordType->isStructureType());

// If the struct inherits from other structs, visit the bases.
const auto *decl = valType->getAsCXXRecordDecl();
for (auto baseIt = decl->bases_begin(), baseIe = decl->bases_end();
baseIt != baseIe; ++baseIt, ++index) {
SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
baseIt->getType(), srcVal, {index}, loc, range);
subSrcVal->setLayoutRule(srcVal->getLayoutRule());
elements.push_back(
reconstructValue(subSrcVal, baseIt->getType(), dstLR, loc, range));
}
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
const StructType *spirvStructType =
lowerStructType(spirvOptions, lowerTypeVisitor, recordType->desugar());

llvm::SmallVector<SpirvInstruction *, 4> elements;
forEachSpirvField(
recordType, spirvStructType,
[&](size_t spirvFieldIndex, const QualType &fieldType,
const auto &field) {
SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
fieldType, srcVal, {static_cast<uint32_t>(spirvFieldIndex)}, loc, range);
subSrcVal->setLayoutRule(srcVal->getLayoutRule());
elements.push_back(
reconstructValue(subSrcVal, fieldType, dstLR, loc, range));

return true;
});

// Go over struct fields.
for (const auto *field : recordType->getDecl()->fields()) {
SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
field->getType(), srcVal, {index}, loc, range);
subSrcVal->setLayoutRule(srcVal->getLayoutRule());
elements.push_back(
reconstructValue(subSrcVal, field->getType(), dstLR, loc, range));
++index;
}
auto *result = spvBuilder.createCompositeConstruct(
valType, elements, srcVal->getSourceLocation(), range);
result->setLayoutRule(dstLR);
Expand Down Expand Up @@ -6947,47 +6992,35 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType,
SourceRange range) {
assert(astStructType->isStructureType());

const auto *structDecl = astStructType->getAsStructureType()->getDecl();
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
const StructType *spirvStructType =
lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);

uint32_t vectorIndex = 0;
uint32_t elemCount = 1;
uint32_t lastConvertedIndex = 0;
llvm::SmallVector<SpirvInstruction *, 4> members;
for (auto field = structDecl->field_begin(); field != structDecl->field_end();
field++) {
// Multiple bitfields can share the same storing type. In such case, we only
// want to append the whole storage once.
const size_t astFieldIndex =
std::distance(structDecl->field_begin(), field);
const uint32_t currentFieldIndex =
spirvStructType->getFields()[astFieldIndex].fieldIndex;
if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
continue;
}
lastConvertedIndex = currentFieldIndex;

if (isScalarType(field->getType())) {
members.push_back(spvBuilder.createCompositeExtract(
elemType, vector, {vectorIndex++}, loc, range));
continue;
}

if (isVectorType(field->getType(), nullptr, &elemCount)) {
llvm::SmallVector<uint32_t, 4> indices;
for (uint32_t i = 0; i < elemCount; ++i)
indices.push_back(vectorIndex++);

members.push_back(spvBuilder.createVectorShuffle(
astContext.getExtVectorType(elemType, elemCount), vector, vector,
indices, loc, range));
continue;
}

assert(false && "unhandled type");
}
forEachSpirvField(astStructType->getAs<RecordType>(), spirvStructType,
[&](size_t spirvFieldIndex, const QualType &fieldType,
const auto &field) {
if (isScalarType(fieldType)) {
members.push_back(spvBuilder.createCompositeExtract(
elemType, vector, {vectorIndex++}, loc, range));
return true;
}

if (isVectorType(fieldType, nullptr, &elemCount)) {
llvm::SmallVector<uint32_t, 4> indices;
for (uint32_t i = 0; i < elemCount; ++i)
indices.push_back(vectorIndex++);

members.push_back(spvBuilder.createVectorShuffle(
astContext.getExtVectorType(elemType, elemCount),
vector, vector, indices, loc, range));
return true;
}

assert(false && "unhandled type");
return false;
});

return spvBuilder.createCompositeConstruct(
astStructType, members, vector->getSourceLocation(), range);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: %dxc -T cs_6_0 -E main -HV 2021

struct Base {
uint base;
};

struct Derived : Base {
uint a;
uint b : 3;
uint c : 3;
uint d;
};

RWStructuredBuffer<Derived> g_probes : register(u0);

[numthreads(64u, 1u, 1u)]
void main(uint3 dispatchThreadId : SV_DispatchThreadID) {

// CHECK: [[p:%\w+]] = OpVariable %_ptr_Function_Derived_0 Function
Derived p;

// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_Base_0 [[p]] %uint_0
// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[tmp]] %int_0
// CHECK: OpStore [[tmp]] %uint_5
p.base = 5;

// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_1
// CHECK: OpStore [[tmp]] %uint_1
p.a = 1;

// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_2
// CHECK: [[value:%\d+]] = OpLoad %uint [[tmp]]
// CHECK: [[value:%\d+]] = OpBitFieldInsert %uint [[value]] %uint_2 %uint_0 %uint_3
// CHECK: OpStore [[tmp]] [[value]]
p.b = 2;

// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_2
// CHECK: [[value:%\d+]] = OpLoad %uint [[tmp]]
// CHECK: [[value:%\d+]] = OpBitFieldInsert %uint [[value]] %uint_3 %uint_3 %uint_3
// CHECK: OpStore [[tmp]] [[value]]
p.c = 3;

// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_3
// CHECK: OpStore [[tmp]] %uint_4
p.d = 4;


// CHECK: [[p:%\d+]] = OpLoad %Derived_0 [[p]]
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_Derived %g_probes %int_0 %uint_0
// CHECK: [[tmp:%\d+]] = OpCompositeExtract %Base_0 [[p]] 0
// CHECK: [[tmp:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
// CHECK: [[base:%\d+]] = OpCompositeConstruct %Base [[tmp]]
// CHECK: [[mem1:%\d+]] = OpCompositeExtract %uint [[p]] 1
// CHECK: [[mem2:%\d+]] = OpCompositeExtract %uint [[p]] 2
// CHECK: [[mem3:%\d+]] = OpCompositeExtract %uint [[p]] 3
// CHECK: [[tmp:%\d+]] = OpCompositeConstruct %Derived [[base]] [[mem1]] [[mem2]] [[mem3]]
// CHECK: OpStore [[ptr]] [[tmp]]
g_probes[0] = p;
}

3 changes: 3 additions & 0 deletions tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ TEST_F(FileTest, OpStructuredBufferAccess) {
TEST_F(FileTest, OpStructuredBufferAccessBitfield) {
runFileTest("op.structured-buffer.access.bitfield.hlsl");
}
TEST_F(FileTest, OpStructuredBufferReconstructBitfield) {
runFileTest("op.structured-buffer.reconstruct.bitfield.hlsl");
}
TEST_F(FileTest, OpRWStructuredBufferAccess) {
runFileTest("op.rw-structured-buffer.access.hlsl");
}
Expand Down

0 comments on commit dcf754c

Please sign in to comment.