diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index d9a246325e..a617b09a56 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -3816,14 +3816,20 @@ SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods( if (isCompareExchange || isCompareStore) { auto *comparator = doExpr(expr->getArg(1)); - auto *originalVal = spvBuilder.createAtomicCompareExchange( + SpirvInstruction *originalVal = spvBuilder.createAtomicCompareExchange( astContext.UnsignedIntTy, ptr, spv::Scope::Device, spv::MemorySemanticsMask::MaskNone, spv::MemorySemanticsMask::MaskNone, doExpr(expr->getArg(2)), comparator, expr->getCallee()->getExprLoc(), range); - if (isCompareExchange) + if (isCompareExchange) { + auto *resultAddress = expr->getArg(3); + QualType resultType = resultAddress->getType(); + if (resultType != astContext.UnsignedIntTy) + originalVal = castToInt(originalVal, astContext.UnsignedIntTy, + resultType, expr->getArg(3)->getLocStart()); spvBuilder.createStore(doExpr(expr->getArg(3)), originalVal, expr->getArg(3)->getLocStart(), range); + } } else { auto *value = doExpr(expr->getArg(1)); SpirvInstruction *originalVal = spvBuilder.createAtomicOp( diff --git a/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl b/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl index da6242a0be..3181c19e76 100644 --- a/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl +++ b/tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl @@ -8,6 +8,7 @@ RWByteAddressBuffer myBuffer; float4 main() : SV_Target { uint originalVal; + int originalValAsInt; // CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2 // CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]] @@ -102,6 +103,13 @@ float4 main() : SV_Target // CHECK-NEXT: OpStore %originalVal [[val]] myBuffer.InterlockedCompareExchange(/*offset=*/16, /*compare_value=*/30, /*value=*/42, originalVal); +// CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2 +// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]] +// CHECK-NEXT: [[val:%\d+]] = OpAtomicCompareExchange %uint [[ptr]] %uint_1 %uint_0 %uint_0 %uint_42 %uint_30 +// CHECK-NEXT: [[cast:%\d+]] = OpBitcast %int [[val]] +// CHECK-NEXT: OpStore %originalValAsInt [[cast]] + myBuffer.InterlockedCompareExchange(/*offset=*/16, /*compare_value=*/30, /*value=*/42, originalValAsInt); + // CHECK: [[offset:%\d+]] = OpShiftRightLogical %uint %uint_16 %uint_2 // CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset]] // CHECK-NEXT: [[val:%\d+]] = OpAtomicCompareExchange %uint [[ptr]] %uint_1 %uint_0 %uint_0 %uint_42 %uint_30