Skip to content

Commit

Permalink
WIP: Force Inline If RefType (shader-slang#4005)
Browse files Browse the repository at this point in the history
* Force Inline if reftype

Fixes shader-slang#3997.
If we are using a refType, we now ForceInline.

remarks:
1. Modifications were made in slang-ir-glsl-legalize to change how we translate GlobalParam proxy's into GlobalParam.
 a. We now handle the senario where a globalParam is used in multiple disjoint blocks (like 2 different functions).

* try to figure out why CI fails but local works

try to inline DispatchMesh, works locally, may fail on CI(?)

* try another fix

* add task tests + don't allow semi-early task-shader inline

Task shader uses DispatchMesh which is a very big 'hack' where we check for the function name and modify the callees in very large ways. This function does inline, but it cannot inline early due to future mangling that this operation requires todo. This is reflected with the `[noRefInline]` modifier. It is a modifier so users may stop mandatory inlines with `__ref` parameter.
  • Loading branch information
ArielG-NV authored Apr 26, 2024
1 parent bc7231b commit e91bd3b
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 35 deletions.
5 changes: 4 additions & 1 deletion source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -2636,4 +2636,7 @@ __attributeTarget(FuncDecl)
attribute_syntax [DerivativeGroupQuad] : DerivativeGroupQuadAttribute;

__attributeTarget(FuncDecl)
attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute;
attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute;

__attributeTarget(FuncDecl)
attribute_syntax [noRefInline] : NoRefInlineAttribute;
3 changes: 3 additions & 0 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -11580,8 +11580,11 @@ void SetMeshOutputCounts(uint vertexCount, uint primitiveCount)
//
// This function doesn't return.
//
// This function cannot be inlined due to a legalization pass happening mid-way through processing
// and later more processing happening to the function which requires eventual inlining.
[KnownBuiltin("DispatchMesh")]
[require(glsl_hlsl_spirv, meshshading)]
[noRefInline]
void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint threadGroupCountZ, __ref P meshPayload)
{
__target_switch
Expand Down
7 changes: 7 additions & 0 deletions source/slang/slang-ast-modifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,13 @@ class NoInlineAttribute : public Attribute
SLANG_AST_CLASS(NoInlineAttribute)
};

/// A `[noRefInline]` attribute represents a request to not force inline a
/// function specifically due to a refType parameter.
class NoRefInlineAttribute : public Attribute
{
SLANG_AST_CLASS(NoRefInlineAttribute)
};

class DerivativeGroupQuadAttribute : public Attribute
{
SLANG_AST_CLASS(DerivativeGroupQuadAttribute)
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ Result linkAndOptimizeIR(
{
// We could fail because
// 1) It's not inlinable for some reason (for example if it's recursive)
SLANG_RETURN_ON_FAIL(performStringInlining(irModule, sink));
SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink));
}

lowerReinterpret(targetProgram, irModule, sink);
Expand Down
64 changes: 42 additions & 22 deletions source/slang/slang-ir-glsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2725,19 +2725,18 @@ void legalizeEntryPointParameterForGLSL(
codeGenContext,
builder, paramType, paramLayout, LayoutResourceKind::VaryingInput, stage, pp);

// Next we need to replace uses of the parameter with
// references to the variable(s). We are going to do that
// somewhat naively, by simply materializing the
// variables at the start.
// we have a simple struct which represents all materialized GlobalParams, this
// struct will replace the no longer needed global variable which proxied as a
// GlobalParam.
IRInst* materialized = materializeValue(builder, globalValue);

// We next need to replace all uses of the proxy variable with the actual GlobalParam
pp->replaceUsesWith(materialized);

// We finally need to replace all global variable references of a global
// parameter with the actual global parameter for all function calls.
// Global parameters are used with a OpStore to copy its data into a global
// variable intermediary. We will follow the uses of a global parameter until
// we find this OpStore, then we will replace uses of the intermediary object.
// GlobalParams use use a OpStore to copy its data into a global
// variable intermediary. We will follow the uses of this intermediary
// and replace all some of the uses (function calls and SPIRV Operands)
Dictionary<IRBlock*, IRInst*> blockToMaterialized;
IRBuilder replaceBuilder(materialized);
for (auto dec : pp->getDecorations())
{
Expand All @@ -2747,27 +2746,48 @@ void legalizeEntryPointParameterForGLSL(
auto globalVarType = cast<IRPtrTypeBase>(globalVar->getDataType())->getValueType();
auto key = dec->getOperand(1);

// we will be replacing uses of `globalVarToReplace`, we need globalVarToReplaceNextUse
// to catch the next use before it is removed from the list of uses
// we will be replacing uses of `globalVarToReplace`. We need globalVarToReplaceNextUse
// to catch the next use before it is removed from the list of uses.
IRUse* globalVarToReplaceNextUse;
for (auto globalVarUse = globalVar->firstUse; globalVarUse; globalVarUse = globalVarToReplaceNextUse)
{
globalVarToReplaceNextUse = globalVarUse->nextUse;
auto user = globalVarUse->getUser();
if (user->getOp() != kIROp_Call)
continue;
for (Slang::UInt operandIndex = 0; operandIndex < user->getOperandCount();
operandIndex++)
switch (user->getOp())
{
auto operand = user->getOperand(operandIndex);
auto operandUse = user->getOperands() + operandIndex;
if (operand != globalVar)
continue;
replaceBuilder.setInsertBefore(user);
auto field = replaceBuilder.emitFieldExtract(globalVarType, materialized, key);
replaceBuilder.replaceOperand(operandUse, field);
case kIROp_SPIRVAsmOperandInst:
case kIROp_Call:
{
for (Slang::UInt operandIndex = 0; operandIndex < user->getOperandCount();
operandIndex++)
{
auto operand = user->getOperand(operandIndex);
auto operandUse = user->getOperands() + operandIndex;
if (operand != globalVar)
continue;

// a GlobalParam may be used across functions/blocks, we need to
// materialize at a minimum 1 struct per block.
auto callingBlock = getBlock(user);
bool found = blockToMaterialized.tryGetValue(callingBlock, materialized);
if (!found)
{
replaceBuilder.setInsertBefore(callingBlock->getFirstInst());
materialized = materializeValue(&replaceBuilder, globalValue);
blockToMaterialized.set(callingBlock, materialized);
}

replaceBuilder.setInsertBefore(user);
auto field = replaceBuilder.emitFieldExtract(globalVarType, materialized, key);
replaceBuilder.replaceOperand(operandUse, field);
break;
}
break;
}
default:
break;
}
continue;
}
}
}
Expand Down
20 changes: 13 additions & 7 deletions source/slang/slang-ir-inline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,15 +709,15 @@ void performMandatoryEarlyInlining(IRModule* module)
namespace { // anonymous

// Inlines calls that involve String types
struct StringInliningPass : InliningPassBase
struct TypeInliningPass : InliningPassBase
{
typedef InliningPassBase Super;

StringInliningPass(IRModule* module)
TypeInliningPass(IRModule* module)
: Super(module)
{}

bool doesTypeRequireInline(IRType* type)
bool doesTypeRequireInline(IRType* type, IRFunc* callee)
{
// TODO(JS):
// I guess there is a question here about what type around string requires
Expand All @@ -727,6 +727,12 @@ struct StringInliningPass : InliningPassBase
const auto op = type->getOp();
switch (op)
{
case kIROp_RefType:
{
if(callee->findDecoration<IRNoRefInlineDecoration>())
return false;
return true;
}
case kIROp_StringType:
case kIROp_NativeStringType:
{
Expand All @@ -742,15 +748,15 @@ struct StringInliningPass : InliningPassBase
{
auto callee = info.callee;

if (doesTypeRequireInline(callee->getResultType()))
if (doesTypeRequireInline(callee->getResultType(), callee))
{
return true;
}

const auto count = Count(callee->getParamCount());
for (Index i = 0; i < count; ++i)
{
if (doesTypeRequireInline(callee->getParamType(UInt(i))))
if (doesTypeRequireInline(callee->getParamType(UInt(i)), callee))
{
return true;
}
Expand All @@ -762,7 +768,7 @@ struct StringInliningPass : InliningPassBase

} // anonymous

Result performStringInlining(IRModule* module, DiagnosticSink* sink)
Result performTypeInlining(IRModule* module, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);

Expand All @@ -780,7 +786,7 @@ Result performStringInlining(IRModule* module, DiagnosticSink* sink)
//
while(true)
{
StringInliningPass pass(module);
TypeInliningPass pass(module);
if (pass.considerAllCallSites())
{
// If there was a change try inlining again
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-ir-inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ namespace Slang
struct IRGlobalValueWithCode;
class DiagnosticSink;

/// Any call to a function that takes or returns a string parameter is inlined
Result performStringInlining(IRModule* module, DiagnosticSink* sink);
/// Any call to a function that takes or returns a string/RefType parameter is inlined
Result performTypeInlining(IRModule* module, DiagnosticSink* sink);

/// Inline any call sites to functions marked `[unsafeForceInlineEarly]`
void performMandatoryEarlyInlining(IRModule* module);
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)

/// Applie to an IR function and signals that inlining should not be performed unless unavoidable.
INST(NoInlineDecoration, noInline, 0, 0)
INST(NoRefInlineDecoration, noRefInline, 0, 0)

INST(DerivativeGroupQuadDecoration, DerivativeGroupQuad, 0, 0)
INST(DerivativeGroupLinearDecoration, DerivativeGroupLinear, 0, 0)
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ IR_SIMPLE_DECORATION(HLSLExportDecoration)
IR_SIMPLE_DECORATION(KeepAliveDecoration)
IR_SIMPLE_DECORATION(RequiresNVAPIDecoration)
IR_SIMPLE_DECORATION(NoInlineDecoration)
IR_SIMPLE_DECORATION(NoRefInlineDecoration)
IR_SIMPLE_DECORATION(DerivativeGroupQuadDecoration)
IR_SIMPLE_DECORATION(DerivativeGroupLinearDecoration)
IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration)
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9738,6 +9738,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
derivativeGroupLinearDecor = getBuilder()->addSimpleDecoration<IRDerivativeGroupLinearDecoration>(irFunc);
}
else if (auto noRefInlineAttribute = as<NoRefInlineAttribute>(modifier))
{
getBuilder()->addSimpleDecoration<IRNoRefInlineDecoration>(irFunc);
}
else if (auto instanceAttr = as<InstanceAttribute>(modifier))
{
IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr);
Expand Down
23 changes: 23 additions & 0 deletions tests/bugs/gh-3997.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -O0 -g

//CHECK: OpEntryPoint

float atomicAdd(__ref float value, float amount)
{
__target_switch
{
case cpp:
__requirePrelude("#include <atomic>");
__intrinsic_asm "std::atomic_ref(*$0).fetch_add($1)";
case spirv:
return __atomicAdd(value, amount);
}
}

RWStructuredBuffer<float> outputBuffer;

[numthreads(4, 1, 1)]
[shader("compute")]
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) {
atomicAdd(outputBuffer[0], 1);
}
4 changes: 2 additions & 2 deletions tests/language-feature/non-copyable-return.slang
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
{
let f = myFunc0(2.0);
// CHECK: 4.0
// GLSL: void myFunc1_0(float y{{.*}}, spirv_by_reference MyType_0 {{.*}})
// GLSL: void myFunc0_0(float x{{.*}}, spirv_by_reference MyType_0 {{.*}})
// GLSL: main(
// GLSL-NOT: MyType {{.*}} =
outputBuffer[0] = f.x;
}
5 changes: 5 additions & 0 deletions tests/pipeline/rasterization/mesh/task-simple.slang
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -task -output-using-type -dx12 -use-dxil -profile sm_6_6 -render-features mesh-shader
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -task -output-using-type -vk -profile glsl_450+spirv_1_4 -render-features mesh-shader
//TEST:SIMPLE(filecheck=HLSL):-target hlsl -entry meshMain -stage mesh
//TEST:SIMPLE(filecheck=CHECK_SPV):-target spirv -entry taskMain -stage amplification

// CHECK_SPV: OpEntryPoint
// CHECK_SPV: TaskPayloadWorkgroupEXT


// To test a simple mesh shader, we'll generate 4 triangles, the vertices of
// each one will hold the triangle index and a value (the square). The fragment
Expand Down

0 comments on commit e91bd3b

Please sign in to comment.