Skip to content

Commit

Permalink
Do recursive function checks early during IR linking (shader-slang#5777)
Browse files Browse the repository at this point in the history
  • Loading branch information
fairywreath authored Dec 6, 2024
1 parent d4136c9 commit ecc5a39
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 111 deletions.
3 changes: 2 additions & 1 deletion source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-bind-existentials.h"
#include "slang-ir-byte-address-legalize.h"
#include "slang-ir-check-recursive-type.h"
#include "slang-ir-check-recursion.h"
#include "slang-ir-check-shader-parameter-type.h"
#include "slang-ir-check-unsupported-inst.h"
#include "slang-ir-cleanup-void.h"
Expand Down Expand Up @@ -884,6 +884,7 @@ Result linkAndOptimizeIR(
if (targetProgram->getOptionSet().shouldRunNonEssentialValidation())
{
checkForRecursiveTypes(irModule, sink);
checkForRecursiveFunctions(codeGenContext->getTargetReq(), irModule, sink);

// For some targets, we are more restrictive about what types are allowed
// to be used as shader parameters in ConstantBuffer/ParameterBlock.
Expand Down
126 changes: 126 additions & 0 deletions source/slang/slang-ir-check-recursion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "slang-ir-check-recursion.h"

#include "slang-ir-util.h"

namespace Slang
{
bool checkTypeRecursionImpl(
HashSet<IRInst*>& checkedTypes,
HashSet<IRInst*>& stack,
IRInst* type,
IRInst* field,
DiagnosticSink* sink)
{
auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool
{
if (!stack.add(elementType))
{
sink->diagnose(field ? field : type, Diagnostics::recursiveType, type);
return false;
}
if (checkedTypes.add(elementType))
checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink);
stack.remove(elementType);
return true;
};
if (auto arrayType = as<IRArrayTypeBase>(type))
{
return visitElementType(arrayType->getElementType(), field);
}
else if (auto structType = as<IRStructType>(type))
{
for (auto sfield : structType->getFields())
if (!visitElementType(sfield->getFieldType(), sfield))
return false;
}
return true;
}

void checkTypeRecursion(HashSet<IRInst*>& checkedTypes, IRInst* type, DiagnosticSink* sink)
{
HashSet<IRInst*> stack;
if (checkedTypes.add(type))
{
stack.add(type);
checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink);
}
}

void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink)
{
HashSet<IRInst*> checkedTypes;
for (auto globalInst : module->getGlobalInsts())
{
switch (globalInst->getOp())
{
case kIROp_StructType:
{
checkTypeRecursion(checkedTypes, globalInst, sink);
}
break;
default:
break;
}
}
}

bool checkFunctionRecursionImpl(
HashSet<IRFunc*>& checkedFuncs,
HashSet<IRFunc*>& callStack,
IRFunc* func,
DiagnosticSink* sink)
{
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
{
auto callInst = as<IRCall>(inst);
if (!callInst)
continue;
auto callee = as<IRFunc>(callInst->getCallee());
if (!callee)
continue;
if (!callStack.add(callee))
{
sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
return false;
}
if (checkedFuncs.add(callee))
checkFunctionRecursionImpl(checkedFuncs, callStack, callee, sink);
callStack.remove(callee);
}
}
return true;
}

void checkFunctionRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
{
HashSet<IRFunc*> callStack;
if (checkedFuncs.add(func))
{
callStack.add(func);
checkFunctionRecursionImpl(checkedFuncs, callStack, func, sink);
}
}

void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
{
HashSet<IRFunc*> checkedFuncsForRecursionDetection;
for (auto globalInst : module->getGlobalInsts())
{
switch (globalInst->getOp())
{
case kIROp_Func:
if (!isCPUTarget(target))
checkFunctionRecursion(
checkedFuncsForRecursionDetection,
as<IRFunc>(globalInst),
sink);
break;
default:
break;
}
}
}

} // namespace Slang
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ namespace Slang
{
struct IRModule;
class DiagnosticSink;
class TargetRequest;

void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink);

void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink);

} // namespace Slang
65 changes: 0 additions & 65 deletions source/slang/slang-ir-check-recursive-type.cpp

This file was deleted.

44 changes: 0 additions & 44 deletions source/slang/slang-ir-check-unsupported-inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,6 @@

namespace Slang
{
bool isCPUTarget(TargetRequest* targetReq);

bool checkRecursionImpl(
HashSet<IRFunc*>& checkedFuncs,
HashSet<IRFunc*>& callStack,
IRFunc* func,
DiagnosticSink* sink)
{
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
{
auto callInst = as<IRCall>(inst);
if (!callInst)
continue;
auto callee = as<IRFunc>(callInst->getCallee());
if (!callee)
continue;
if (!callStack.add(callee))
{
sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
return false;
}
if (checkedFuncs.add(callee))
checkRecursionImpl(checkedFuncs, callStack, callee, sink);
callStack.remove(callee);
}
}
return true;
}

void checkRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
{
HashSet<IRFunc*> callStack;
if (checkedFuncs.add(func))
{
callStack.add(func);
checkRecursionImpl(checkedFuncs, callStack, func, sink);
}
}

void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* sink)
{
Expand All @@ -65,8 +25,6 @@ void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* s

void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
{
HashSet<IRFunc*> checkedFuncsForRecursionDetection;

for (auto globalInst : module->getGlobalInsts())
{
switch (globalInst->getOp())
Expand All @@ -84,8 +42,6 @@ void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSin
break;
}
case kIROp_Func:
if (!isCPUTarget(target))
checkRecursion(checkedFuncsForRecursionDetection, as<IRFunc>(globalInst), sink);
checkUnsupportedInst(target, as<IRFunc>(globalInst), sink);
break;
case kIROp_Generic:
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-bit-field-accessors.h"
#include "slang-ir-check-differentiability.h"
#include "slang-ir-check-recursive-type.h"
#include "slang-ir-check-recursion.h"
#include "slang-ir-clone.h"
#include "slang-ir-constexpr.h"
#include "slang-ir-dce.h"
Expand Down

0 comments on commit ecc5a39

Please sign in to comment.