From e8b1fcb05fcae40652e32fbab75a905a7ad208bf Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Tue, 20 Jun 2023 08:19:58 -0700 Subject: [PATCH] Add cast after RWByteAddressBuffer atomic compare exchange (#5295) * Add cast after RWByteAddressBuffer atomic compare exchange The type of the RWByteAddressBuffer is always uint in the SPIR-V representation. This means that the OpAtomicCompareExchange instruction must have parameters and a result of type uint. If the original code used an in for the result, then there is a type mismatch when storing the value. This is fixed by add a cast when appropriate. Fixes #4741 --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 10 ++++++++-- .../method.rw-byte-address-buffer.atomic.hlsl | 8 ++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index d9a246325e..a617b09a56 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -3816,14 +3816,20 @@ SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods( if (isCompareExchange || isCompareStore) { auto *comparator = doExpr(expr->getArg(1)); - auto *originalVal = spvBuilder.createAtomicCompareExchange( + SpirvInstruction *originalVal = spvBuilder.createAtomicCompareExchange( astContext.UnsignedIntTy, ptr, spv::Scope::Device, spv::MemorySemanticsMask::MaskNone, spv::MemorySemanticsMask::MaskNone, doExpr(expr->getArg(2)), comparator, expr->getCallee()->getExprLoc(), range); - if (isCompareExchange) + if (isCompareExchange) { + auto *resultAddress = expr->getArg(3); + QualType resultType = resultAddress->getType(); + if (resultType != astContext.UnsignedIntTy) + originalVal = castToInt(originalVal, astContext.UnsignedIntTy, + resultType, expr->getArg(3)->getLocStart()); spvBuilder.createStore(doExpr(expr->getArg(3)), originalVal, expr->getArg(3)->getLocStart(), range); + } } else { auto *value = doExpr(expr->getArg(1)); SpirvInstruction *originalVal = spvBuilder.createAtomicOp( diff --git a/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl b/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl index da6242a0be..3181c19e76 100644 --- a/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl +++ b/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl @@ -8,6 +8,7 @@ RWByteAddressBuffer myBuffer; float4 main() : SV_Target { uint originalVal; + int originalValAsInt; // CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2 // CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]] @@ -102,6 +103,13 @@ float4 main() : SV_Target // CHECK-NEXT: OpStore %originalVal [[val]] myBuffer.InterlockedCompareExchange(/*offset=*/16, /*compare_value=*/30, /*value=*/42, originalVal); +// CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2 +// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]] +// CHECK-NEXT: [[val:%\d+]] = OpAtomicCompareExchange %uint [[ptr]] %uint_1 %uint_0 %uint_0 %uint_42 %uint_30 +// CHECK-NEXT: [[cast:%\d+]] = OpBitcast %int [[val]] +// CHECK-NEXT: OpStore %originalValAsInt [[cast]] + myBuffer.InterlockedCompareExchange(/*offset=*/16, /*compare_value=*/30, /*value=*/42, originalValAsInt); + // CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2 // CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]] // CHECK-NEXT: [[val:%\d+]] = OpAtomicCompareExchange %uint [[ptr]] %uint_1 %uint_0 %uint_0 %uint_42 %uint_30