Skip to content

Commit

Permalink
Add cast after RWByteAddressBuffer atomic compare exchange (microsoft…
Browse files Browse the repository at this point in the history
…#5295)

* Add cast after RWByteAddressBuffer atomic compare exchange

The type of the RWByteAddressBuffer is always uint in the SPIR-V
representation. This means that the OpAtomicCompareExchange instruction
must have parameters and a result of type uint. If the original code
used an in for the result, then there is a type mismatch when storing
the value. This is fixed by add a cast when appropriate.

Fixes microsoft#4741
  • Loading branch information
s-perron authored Jun 20, 2023
1 parent 0ed5abb commit e8b1fcb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3816,14 +3816,20 @@ SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods(

if (isCompareExchange || isCompareStore) {
auto *comparator = doExpr(expr->getArg(1));
auto *originalVal = spvBuilder.createAtomicCompareExchange(
SpirvInstruction *originalVal = spvBuilder.createAtomicCompareExchange(
astContext.UnsignedIntTy, ptr, spv::Scope::Device,
spv::MemorySemanticsMask::MaskNone, spv::MemorySemanticsMask::MaskNone,
doExpr(expr->getArg(2)), comparator, expr->getCallee()->getExprLoc(),
range);
if (isCompareExchange)
if (isCompareExchange) {
auto *resultAddress = expr->getArg(3);
QualType resultType = resultAddress->getType();
if (resultType != astContext.UnsignedIntTy)
originalVal = castToInt(originalVal, astContext.UnsignedIntTy,
resultType, expr->getArg(3)->getLocStart());
spvBuilder.createStore(doExpr(expr->getArg(3)), originalVal,
expr->getArg(3)->getLocStart(), range);
}
} else {
auto *value = doExpr(expr->getArg(1));
SpirvInstruction *originalVal = spvBuilder.createAtomicOp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ RWByteAddressBuffer myBuffer;
float4 main() : SV_Target
{
uint originalVal;
int originalValAsInt;

// CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2
// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]]
Expand Down Expand Up @@ -102,6 +103,13 @@ float4 main() : SV_Target
// CHECK-NEXT: OpStore %originalVal [[val]]
myBuffer.InterlockedCompareExchange(/*offset=*/16, /*compare_value=*/30, /*value=*/42, originalVal);

// CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2
// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]]
// CHECK-NEXT: [[val:%\d+]] = OpAtomicCompareExchange %uint [[ptr]] %uint_1 %uint_0 %uint_0 %uint_42 %uint_30
// CHECK-NEXT: [[cast:%\d+]] = OpBitcast %int [[val]]
// CHECK-NEXT: OpStore %originalValAsInt [[cast]]
myBuffer.InterlockedCompareExchange(/*offset=*/16, /*compare_value=*/30, /*value=*/42, originalValAsInt);

// CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2
// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]]
// CHECK-NEXT: [[val:%\d+]] = OpAtomicCompareExchange %uint [[ptr]] %uint_1 %uint_0 %uint_0 %uint_42 %uint_30
Expand Down

0 comments on commit e8b1fcb

Please sign in to comment.