Skip to content

Commit

Permalink
SPIRV: Fix performance issue when handling large arrays. (shader-slan…
Browse files Browse the repository at this point in the history
…g#4064)

* SPIRV: Fix performance issue when handling large arrays.

* Add test for packing.

* Fix clang.
  • Loading branch information
csyonghe authored May 1, 2024
1 parent 4533c82 commit 0bb826f
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 14 deletions.
76 changes: 64 additions & 12 deletions source/slang/slang-ir-lower-buffer-element-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace Slang
{
struct LoweredElementTypeContext
{
static const IRIntegerValue kMaxArraySizeToUnroll = 32;

struct LoweredElementTypeInfo
{
IRType* originalType;
Expand Down Expand Up @@ -161,17 +163,42 @@ namespace Slang
auto packedParam = builder.emitParam(structType);
auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey);
auto count = getIntVal(arrayType->getElementCount());
List<IRInst*> args;
args.setCount((Index)count);
for (IRIntegerValue ii = 0; ii < count; ++ii)
IRInst* result = nullptr;
if (count <= kMaxArraySizeToUnroll)
{
// If the array is small enough, just process each element directly.
List<IRInst*> args;
args.setCount((Index)count);
for (IRIntegerValue ii = 0; ii < count; ++ii)
{
auto packedElement = builder.emitElementExtract(packedArray, ii);
auto originalElement = innerTypeInfo.convertLoweredToOriginal
? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement)
: packedElement;
args[(Index)ii] = originalElement;
}
result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());

}
else
{
auto packedElement = builder.emitElementExtract(packedArray, ii);
// The general case for large arrays is to emit a loop through the elements.
IRVar* resultVar = builder.emitVar(arrayType);
IRBlock* loopBodyBlock;
IRBlock* loopBreakBlock;
auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count),
loopBodyBlock, loopBreakBlock);

builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
auto packedElement = builder.emitElementExtract(packedArray, loopParam);
auto originalElement = innerTypeInfo.convertLoweredToOriginal
? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement)
: packedElement;
args[(Index)ii] = originalElement;
auto varPtr = builder.emitElementAddress(resultVar, loopParam);
builder.emitStore(varPtr, originalElement);
builder.setInsertInto(loopBreakBlock);
result = builder.emitLoad(resultVar);
}
auto result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
builder.emitReturn(result);
return func;
}
Expand All @@ -191,18 +218,43 @@ namespace Slang
builder.setInsertInto(func);
builder.emitBlock();
auto originalParam = builder.emitParam(arrayType);
IRInst* packedArray = nullptr;
auto count = getIntVal(arrayType->getElementCount());
List<IRInst*> args;
args.setCount((Index)count);
for (IRIntegerValue ii = 0; ii < count; ++ii)
if (count <= kMaxArraySizeToUnroll)
{
// If the array is small enough, just process each element directly.
List<IRInst*> args;
args.setCount((Index)count);
for (IRIntegerValue ii = 0; ii < count; ++ii)
{
auto originalElement = builder.emitElementExtract(originalParam, ii);
auto packedElement = innerTypeInfo.convertOriginalToLowered
? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement)
: originalElement;
args[(Index)ii] = packedElement;
}
packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer());
}
else
{
auto originalElement = builder.emitElementExtract(originalParam, ii);
// The general case for large arrays is to emit a loop through the elements.
IRVar* packedArrayVar = builder.emitVar(innerArrayType);
IRBlock* loopBodyBlock;
IRBlock* loopBreakBlock;
auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count),
loopBodyBlock, loopBreakBlock);

builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
auto originalElement = builder.emitElementExtract(originalParam, loopParam);
auto packedElement = innerTypeInfo.convertOriginalToLowered
? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement)
: originalElement;
args[(Index)ii] = packedElement;
auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam);
builder.emitStore(varPtr, packedElement);
builder.setInsertInto(loopBreakBlock);
packedArray = builder.emitLoad(packedArrayVar);
}
auto packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer());

auto result = builder.emitMakeStruct(structType, 1, &packedArray);
builder.emitReturn(result);
return func;
Expand Down
5 changes: 4 additions & 1 deletion source/slang/slang-ir-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,18 +872,21 @@ IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IR
IRBuilder loopBuilder = *builder;
auto loopHeadBlock = loopBuilder.emitBlock();
loopBodyBlock = loopBuilder.emitBlock();
auto ifBreakBlock = loopBuilder.emitBlock();
loopBreakBlock = loopBuilder.emitBlock();
auto loopContinueBlock = loopBuilder.emitBlock();
builder->emitLoop(loopHeadBlock, loopBreakBlock, loopHeadBlock, 1, &initVal);
loopBuilder.setInsertInto(loopHeadBlock);
auto loopParam = loopBuilder.emitParam(initVal->getFullType());
auto cmpResult = loopBuilder.emitLess(loopParam, finalVal);
loopBuilder.emitIfElse(cmpResult, loopBodyBlock, loopBreakBlock, loopBreakBlock);
loopBuilder.emitIfElse(cmpResult, loopBodyBlock, ifBreakBlock, ifBreakBlock);
loopBuilder.setInsertInto(loopBodyBlock);
loopBuilder.emitBranch(loopContinueBlock);
loopBuilder.setInsertInto(loopContinueBlock);
auto newParam = loopBuilder.emitAdd(loopParam->getFullType(), loopParam, loopBuilder.getIntValue(loopBuilder.getIntType(), 1));
loopBuilder.emitBranch(loopHeadBlock, 1, &newParam);
loopBuilder.setInsertInto(ifBreakBlock);
loopBuilder.emitBranch(loopBreakBlock);
return loopParam;
}

Expand Down
29 changes: 29 additions & 0 deletions tests/spirv/large-struct-pack.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute

// Check that when generating spirv directly, we use a loop
// to copy large arrays in a local variable to a buffer, instead of emitting
// unrolled code that reads each element of the array individually.

struct WorkData
{
int B[1024];
};

//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4, count=1024)
RWStructuredBuffer<WorkData> resultBuffer;

// CHECK: OpLoopMerge
// CHECK: OpLoopMerge

// BUF: 0
// BUF: 1
[numthreads(1, 1, 1)]
void computeMain(uint3 tid: SV_DispatchThreadID)
{
WorkData wd;
for (int i = 0; i < 1024; i++)
wd.B[i] = i;
resultBuffer[0] = wd;
}
20 changes: 20 additions & 0 deletions tests/spirv/large-struct-ptr.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460

struct WorkData {
float A[2048 * 2048];
float B[2048 * 2048];
};
struct PushData {
WorkData* Input;
float* Dest;
};

[vk::push_constant] ConstantBuffer<PushData> cb;

// CHECK: OpEntryPoint

[numthreads(64, 1, 1)]
void ComputeMain(uint tid: SV_DispatchThreadID)
{
cb.Dest[tid] = cb.Input->A[tid] * cb.Input->B[tid];
}
32 changes: 32 additions & 0 deletions tests/spirv/large-struct.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -profile glsl_460
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-d3d12 -compute -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -output-using-type

// Check that when generating spirv directly, we use a loop
// to copy large arrays in input data out into a local variable, instead of emitting
// unrolled code that reads each element of the array individually.

struct WorkData
{
float A[2*2];
float B[1024];

float Foo(uint i) { return A[i] * B[i]; }
};

//TEST_INPUT:set input = new WorkData{[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]}
ConstantBuffer<WorkData> input;

//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<float> resultBuffer;

// CHECK: OpLoopMerge

[numthreads(2, 1, 1)]
void computeMain(uint3 tid: SV_DispatchThreadID)
{
// BUF: 10.0
// BUF: 40.0
resultBuffer[tid.x] = input.Foo(tid.x);
}
6 changes: 5 additions & 1 deletion tools/render-test/render-test-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ struct AssignValsFromLayoutContext
{
const InputBufferDesc& srcBuffer = srcVal->bufferDesc;
auto& bufferData = srcVal->bufferData;
const size_t bufferSize = bufferData.getCount() * sizeof(uint32_t);
const size_t bufferSize = Math::Max((size_t)bufferData.getCount() * sizeof(uint32_t), (size_t)(srcBuffer.elementCount * srcBuffer.stride));
bufferData.reserve(bufferSize / sizeof(uint32_t));
for (size_t i = bufferData.getCount(); i < bufferSize / sizeof(uint32_t); i++)
bufferData.add(0);

ComPtr<IBufferResource> bufferResource;
SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource(srcBuffer, /*entry.isOutput,*/ bufferSize, bufferData.getBuffer(), device, bufferResource));
Expand All @@ -232,6 +235,7 @@ struct AssignValsFromLayoutContext
const InputBufferDesc& counterBufferDesc{
InputBufferType::StorageBuffer,
sizeof(uint32_t),
1,
Format::Unknown,
};
SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBufferResource(
Expand Down
5 changes: 5 additions & 0 deletions tools/render-test/shader-input-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ namespace renderer_test
parser.Read("=");
val->bufferDesc.stride = parser.ReadInt();
}
else if (word == "count")
{
parser.Read("=");
val->bufferDesc.elementCount = parser.ReadInt();
}
else if (word == "counter")
{
parser.Read("=");
Expand Down
1 change: 1 addition & 0 deletions tools/render-test/shader-input-layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ struct InputBufferDesc
{
InputBufferType type = InputBufferType::StorageBuffer;
int stride = 0; // stride == 0 indicates an unstructured buffer.
int elementCount = 1;
Format format = Format::Unknown;
// For RWStructuredBuffer, AppendStructuredBuffer, ConsumeStructuredBuffer
// the default value of 0xffffffff indicates that a counter buffer should
Expand Down

0 comments on commit 0bb826f

Please sign in to comment.