diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 1b456b10ca..d68584bf6e 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -584,6 +584,54 @@ const StructType *lowerStructType(const SpirvCodeGenOptions &spirvOptions, return output; } +// Calls `operation` on for each field in the base and derives class defined by +// `recordType`. The `operation` will receive the AST type linked to the field, +// the SPIRV type linked to the field, and the index of the field in the final +// SPIR-V representation. This index of the field can vary from the AST +// field-index because bitfields are merged into a single field in the SPIR-V +// representation. +// +// If the operation returns false, we stop processing fields. +void forEachSpirvField( + const RecordType *recordType, const StructType *spirvType, + std::function + operation) { + const auto *cxxDecl = recordType->getAsCXXRecordDecl(); + const auto *recordDecl = recordType->getDecl(); + + // Iterate through the base class (one field per base class). + // Bases cannot be melded into 1 field like bitfields, simple iteration. + uint32_t lastConvertedIndex = 0; + size_t astFieldIndex = 0; + for (const auto &base : cxxDecl->bases()) { + const auto &type = base.getType(); + const auto &spirvField = spirvType->getFields()[astFieldIndex]; + if (!operation(spirvField.fieldIndex, type, spirvField)) { + return; + } + lastConvertedIndex = spirvField.fieldIndex; + ++astFieldIndex; + } + + // Iterate through the derived class fields. Field could be merged. + for (const auto *field : recordDecl->fields()) { + const auto &spirvField = spirvType->getFields()[astFieldIndex]; + const uint32_t currentFieldIndex = spirvField.fieldIndex; + if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) { + ++astFieldIndex; + continue; + } + + const auto &type = field->getType(); + if (!operation(currentFieldIndex, type, spirvField)) { + return; + } + lastConvertedIndex = currentFieldIndex; + ++astFieldIndex; + } +} + } // namespace SpirvEmitter::SpirvEmitter(CompilerInstance &ci) @@ -6410,29 +6458,26 @@ SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal, // Structs if (const auto *recordType = valType->getAs()) { - uint32_t index = 0; - llvm::SmallVector elements; + assert(recordType->isStructureType()); - // If the struct inherits from other structs, visit the bases. - const auto *decl = valType->getAsCXXRecordDecl(); - for (auto baseIt = decl->bases_begin(), baseIe = decl->bases_end(); - baseIt != baseIe; ++baseIt, ++index) { - SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract( - baseIt->getType(), srcVal, {index}, loc, range); - subSrcVal->setLayoutRule(srcVal->getLayoutRule()); - elements.push_back( - reconstructValue(subSrcVal, baseIt->getType(), dstLR, loc, range)); - } + LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); + const StructType *spirvStructType = + lowerStructType(spirvOptions, lowerTypeVisitor, recordType->desugar()); + + llvm::SmallVector elements; + forEachSpirvField( + recordType, spirvStructType, + [&](size_t spirvFieldIndex, const QualType &fieldType, + const auto &field) { + SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract( + fieldType, srcVal, {static_cast(spirvFieldIndex)}, loc, range); + subSrcVal->setLayoutRule(srcVal->getLayoutRule()); + elements.push_back( + reconstructValue(subSrcVal, fieldType, dstLR, loc, range)); + + return true; + }); - // Go over struct fields. - for (const auto *field : recordType->getDecl()->fields()) { - SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract( - field->getType(), srcVal, {index}, loc, range); - subSrcVal->setLayoutRule(srcVal->getLayoutRule()); - elements.push_back( - reconstructValue(subSrcVal, field->getType(), dstLR, loc, range)); - ++index; - } auto *result = spvBuilder.createCompositeConstruct( valType, elements, srcVal->getSourceLocation(), range); result->setLayoutRule(dstLR); @@ -6947,47 +6992,35 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType, SourceRange range) { assert(astStructType->isStructureType()); - const auto *structDecl = astStructType->getAsStructureType()->getDecl(); LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); const StructType *spirvStructType = lowerStructType(spirvOptions, lowerTypeVisitor, astStructType); - uint32_t vectorIndex = 0; uint32_t elemCount = 1; - uint32_t lastConvertedIndex = 0; llvm::SmallVector members; - for (auto field = structDecl->field_begin(); field != structDecl->field_end(); - field++) { - // Multiple bitfields can share the same storing type. In such case, we only - // want to append the whole storage once. - const size_t astFieldIndex = - std::distance(structDecl->field_begin(), field); - const uint32_t currentFieldIndex = - spirvStructType->getFields()[astFieldIndex].fieldIndex; - if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) { - continue; - } - lastConvertedIndex = currentFieldIndex; - - if (isScalarType(field->getType())) { - members.push_back(spvBuilder.createCompositeExtract( - elemType, vector, {vectorIndex++}, loc, range)); - continue; - } - - if (isVectorType(field->getType(), nullptr, &elemCount)) { - llvm::SmallVector indices; - for (uint32_t i = 0; i < elemCount; ++i) - indices.push_back(vectorIndex++); - - members.push_back(spvBuilder.createVectorShuffle( - astContext.getExtVectorType(elemType, elemCount), vector, vector, - indices, loc, range)); - continue; - } - - assert(false && "unhandled type"); - } + forEachSpirvField(astStructType->getAs(), spirvStructType, + [&](size_t spirvFieldIndex, const QualType &fieldType, + const auto &field) { + if (isScalarType(fieldType)) { + members.push_back(spvBuilder.createCompositeExtract( + elemType, vector, {vectorIndex++}, loc, range)); + return true; + } + + if (isVectorType(fieldType, nullptr, &elemCount)) { + llvm::SmallVector indices; + for (uint32_t i = 0; i < elemCount; ++i) + indices.push_back(vectorIndex++); + + members.push_back(spvBuilder.createVectorShuffle( + astContext.getExtVectorType(elemType, elemCount), + vector, vector, indices, loc, range)); + return true; + } + + assert(false && "unhandled type"); + return false; + }); return spvBuilder.createCompositeConstruct( astStructType, members, vector->getSourceLocation(), range); diff --git a/tools/clang/test/CodeGenSPIRV/op.structured-buffer.reconstruct.bitfield.hlsl b/tools/clang/test/CodeGenSPIRV/op.structured-buffer.reconstruct.bitfield.hlsl new file mode 100644 index 0000000000..a0b55b4080 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/op.structured-buffer.reconstruct.bitfield.hlsl @@ -0,0 +1,60 @@ +// RUN: %dxc -T cs_6_0 -E main -HV 2021 + +struct Base { + uint base; +}; + +struct Derived : Base { + uint a; + uint b : 3; + uint c : 3; + uint d; +}; + +RWStructuredBuffer g_probes : register(u0); + +[numthreads(64u, 1u, 1u)] +void main(uint3 dispatchThreadId : SV_DispatchThreadID) { + +// CHECK: [[p:%\w+]] = OpVariable %_ptr_Function_Derived_0 Function + Derived p; + +// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_Base_0 [[p]] %uint_0 +// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[tmp]] %int_0 +// CHECK: OpStore [[tmp]] %uint_5 + p.base = 5; + +// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_1 +// CHECK: OpStore [[tmp]] %uint_1 + p.a = 1; + +// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_2 +// CHECK: [[value:%\d+]] = OpLoad %uint [[tmp]] +// CHECK: [[value:%\d+]] = OpBitFieldInsert %uint [[value]] %uint_2 %uint_0 %uint_3 +// CHECK: OpStore [[tmp]] [[value]] + p.b = 2; + +// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_2 +// CHECK: [[value:%\d+]] = OpLoad %uint [[tmp]] +// CHECK: [[value:%\d+]] = OpBitFieldInsert %uint [[value]] %uint_3 %uint_3 %uint_3 +// CHECK: OpStore [[tmp]] [[value]] + p.c = 3; + +// CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint [[p]] %int_3 +// CHECK: OpStore [[tmp]] %uint_4 + p.d = 4; + + +// CHECK: [[p:%\d+]] = OpLoad %Derived_0 [[p]] +// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_Derived %g_probes %int_0 %uint_0 +// CHECK: [[tmp:%\d+]] = OpCompositeExtract %Base_0 [[p]] 0 +// CHECK: [[tmp:%\d+]] = OpCompositeExtract %uint [[tmp]] 0 +// CHECK: [[base:%\d+]] = OpCompositeConstruct %Base [[tmp]] +// CHECK: [[mem1:%\d+]] = OpCompositeExtract %uint [[p]] 1 +// CHECK: [[mem2:%\d+]] = OpCompositeExtract %uint [[p]] 2 +// CHECK: [[mem3:%\d+]] = OpCompositeExtract %uint [[p]] 3 +// CHECK: [[tmp:%\d+]] = OpCompositeConstruct %Derived [[base]] [[mem1]] [[mem2]] [[mem3]] +// CHECK: OpStore [[ptr]] [[tmp]] + g_probes[0] = p; +} + diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index 392ce2f947..9d591a8bf3 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -472,6 +472,9 @@ TEST_F(FileTest, OpStructuredBufferAccess) { TEST_F(FileTest, OpStructuredBufferAccessBitfield) { runFileTest("op.structured-buffer.access.bitfield.hlsl"); } +TEST_F(FileTest, OpStructuredBufferReconstructBitfield) { + runFileTest("op.structured-buffer.reconstruct.bitfield.hlsl"); +} TEST_F(FileTest, OpRWStructuredBufferAccess) { runFileTest("op.rw-structured-buffer.access.hlsl"); }