diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index e9fbfc0d14..3605967410 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -9,6 +9,8 @@ namespace Slang { struct LoweredElementTypeContext { + static const IRIntegerValue kMaxArraySizeToUnroll = 32; + struct LoweredElementTypeInfo { IRType* originalType; @@ -161,17 +163,42 @@ namespace Slang auto packedParam = builder.emitParam(structType); auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey); auto count = getIntVal(arrayType->getElementCount()); - List args; - args.setCount((Index)count); - for (IRIntegerValue ii = 0; ii < count; ++ii) + IRInst* result = nullptr; + if (count <= kMaxArraySizeToUnroll) + { + // If the array is small enough, just process each element directly. + List args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + auto packedElement = builder.emitElementExtract(packedArray, ii); + auto originalElement = innerTypeInfo.convertLoweredToOriginal + ? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement) + : packedElement; + args[(Index)ii] = originalElement; + } + result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + + } + else { - auto packedElement = builder.emitElementExtract(packedArray, ii); + // The general case for large arrays is to emit a loop through the elements. + IRVar* resultVar = builder.emitVar(arrayType); + IRBlock* loopBodyBlock; + IRBlock* loopBreakBlock; + auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count), + loopBodyBlock, loopBreakBlock); + + builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); + auto packedElement = builder.emitElementExtract(packedArray, loopParam); auto originalElement = innerTypeInfo.convertLoweredToOriginal ? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement) : packedElement; - args[(Index)ii] = originalElement; + auto varPtr = builder.emitElementAddress(resultVar, loopParam); + builder.emitStore(varPtr, originalElement); + builder.setInsertInto(loopBreakBlock); + result = builder.emitLoad(resultVar); } - auto result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); builder.emitReturn(result); return func; } @@ -191,18 +218,43 @@ namespace Slang builder.setInsertInto(func); builder.emitBlock(); auto originalParam = builder.emitParam(arrayType); + IRInst* packedArray = nullptr; auto count = getIntVal(arrayType->getElementCount()); - List args; - args.setCount((Index)count); - for (IRIntegerValue ii = 0; ii < count; ++ii) + if (count <= kMaxArraySizeToUnroll) + { + // If the array is small enough, just process each element directly. + List args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + auto originalElement = builder.emitElementExtract(originalParam, ii); + auto packedElement = innerTypeInfo.convertOriginalToLowered + ? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement) + : originalElement; + args[(Index)ii] = packedElement; + } + packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); + } + else { - auto originalElement = builder.emitElementExtract(originalParam, ii); + // The general case for large arrays is to emit a loop through the elements. + IRVar* packedArrayVar = builder.emitVar(innerArrayType); + IRBlock* loopBodyBlock; + IRBlock* loopBreakBlock; + auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count), + loopBodyBlock, loopBreakBlock); + + builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); + auto originalElement = builder.emitElementExtract(originalParam, loopParam); auto packedElement = innerTypeInfo.convertOriginalToLowered ? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement) : originalElement; - args[(Index)ii] = packedElement; + auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam); + builder.emitStore(varPtr, packedElement); + builder.setInsertInto(loopBreakBlock); + packedArray = builder.emitLoad(packedArrayVar); } - auto packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &packedArray); builder.emitReturn(result); return func; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index f6b0acaed9..bcb9439fbb 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -872,18 +872,21 @@ IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IR IRBuilder loopBuilder = *builder; auto loopHeadBlock = loopBuilder.emitBlock(); loopBodyBlock = loopBuilder.emitBlock(); + auto ifBreakBlock = loopBuilder.emitBlock(); loopBreakBlock = loopBuilder.emitBlock(); auto loopContinueBlock = loopBuilder.emitBlock(); builder->emitLoop(loopHeadBlock, loopBreakBlock, loopHeadBlock, 1, &initVal); loopBuilder.setInsertInto(loopHeadBlock); auto loopParam = loopBuilder.emitParam(initVal->getFullType()); auto cmpResult = loopBuilder.emitLess(loopParam, finalVal); - loopBuilder.emitIfElse(cmpResult, loopBodyBlock, loopBreakBlock, loopBreakBlock); + loopBuilder.emitIfElse(cmpResult, loopBodyBlock, ifBreakBlock, ifBreakBlock); loopBuilder.setInsertInto(loopBodyBlock); loopBuilder.emitBranch(loopContinueBlock); loopBuilder.setInsertInto(loopContinueBlock); auto newParam = loopBuilder.emitAdd(loopParam->getFullType(), loopParam, loopBuilder.getIntValue(loopBuilder.getIntType(), 1)); loopBuilder.emitBranch(loopHeadBlock, 1, &newParam); + loopBuilder.setInsertInto(ifBreakBlock); + loopBuilder.emitBranch(loopBreakBlock); return loopParam; } diff --git a/tests/spirv/large-struct-pack.slang b/tests/spirv/large-struct-pack.slang new file mode 100644 index 0000000000..df15ac67c0 --- /dev/null +++ b/tests/spirv/large-struct-pack.slang @@ -0,0 +1,29 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute + +// Check that when generating spirv directly, we use a loop +// to copy large arrays in a local variable to a buffer, instead of emitting +// unrolled code that reads each element of the array individually. + +struct WorkData +{ + int B[1024]; +}; + +//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4, count=1024) +RWStructuredBuffer resultBuffer; + +// CHECK: OpLoopMerge +// CHECK: OpLoopMerge + +// BUF: 0 +// BUF: 1 +[numthreads(1, 1, 1)] +void computeMain(uint3 tid: SV_DispatchThreadID) +{ + WorkData wd; + for (int i = 0; i < 1024; i++) + wd.B[i] = i; + resultBuffer[0] = wd; +} diff --git a/tests/spirv/large-struct-ptr.slang b/tests/spirv/large-struct-ptr.slang new file mode 100644 index 0000000000..131ccdbb9c --- /dev/null +++ b/tests/spirv/large-struct-ptr.slang @@ -0,0 +1,20 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460 + +struct WorkData { + float A[2048 * 2048]; + float B[2048 * 2048]; +}; +struct PushData { + WorkData* Input; + float* Dest; +}; + +[vk::push_constant] ConstantBuffer cb; + +// CHECK: OpEntryPoint + +[numthreads(64, 1, 1)] +void ComputeMain(uint tid: SV_DispatchThreadID) +{ + cb.Dest[tid] = cb.Input->A[tid] * cb.Input->B[tid]; +} \ No newline at end of file diff --git a/tests/spirv/large-struct.slang b/tests/spirv/large-struct.slang new file mode 100644 index 0000000000..7738a5fcf3 --- /dev/null +++ b/tests/spirv/large-struct.slang @@ -0,0 +1,32 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-d3d12 -compute -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -output-using-type + +// Check that when generating spirv directly, we use a loop +// to copy large arrays in input data out into a local variable, instead of emitting +// unrolled code that reads each element of the array individually. + +struct WorkData +{ + float A[2*2]; + float B[1024]; + + float Foo(uint i) { return A[i] * B[i]; } +}; + +//TEST_INPUT:set input = new WorkData{[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]} +ConstantBuffer input; + +//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer resultBuffer; + +// CHECK: OpLoopMerge + +[numthreads(2, 1, 1)] +void computeMain(uint3 tid: SV_DispatchThreadID) +{ + // BUF: 10.0 + // BUF: 40.0 + resultBuffer[tid.x] = input.Foo(tid.x); +} \ No newline at end of file diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 02c0ea86a7..fcdd4b54dd 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -207,7 +207,10 @@ struct AssignValsFromLayoutContext { const InputBufferDesc& srcBuffer = srcVal->bufferDesc; auto& bufferData = srcVal->bufferData; - const size_t bufferSize = bufferData.getCount() * sizeof(uint32_t); + const size_t bufferSize = Math::Max((size_t)bufferData.getCount() * sizeof(uint32_t), (size_t)(srcBuffer.elementCount * srcBuffer.stride)); + bufferData.reserve(bufferSize / sizeof(uint32_t)); + for (size_t i = bufferData.getCount(); i < bufferSize / sizeof(uint32_t); i++) + bufferData.add(0); ComPtr bufferResource; SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource(srcBuffer, /*entry.isOutput,*/ bufferSize, bufferData.getBuffer(), device, bufferResource)); @@ -232,6 +235,7 @@ struct AssignValsFromLayoutContext const InputBufferDesc& counterBufferDesc{ InputBufferType::StorageBuffer, sizeof(uint32_t), + 1, Format::Unknown, }; SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource( diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 96f5db6e05..3012d45a44 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -222,6 +222,11 @@ namespace renderer_test parser.Read("="); val->bufferDesc.stride = parser.ReadInt(); } + else if (word == "count") + { + parser.Read("="); + val->bufferDesc.elementCount = parser.ReadInt(); + } else if (word == "counter") { parser.Read("="); diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h index de1da3da9d..996635b94c 100644 --- a/tools/render-test/shader-input-layout.h +++ b/tools/render-test/shader-input-layout.h @@ -68,6 +68,7 @@ struct InputBufferDesc { InputBufferType type = InputBufferType::StorageBuffer; int stride = 0; // stride == 0 indicates an unstructured buffer. + int elementCount = 1; Format format = Format::Unknown; // For RWStructuredBuffer, AppendStructuredBuffer, ConsumeStructuredBuffer // the default value of 0xffffffff indicates that a counter buffer should