From ecc5a39feecbf73feedf352214406c8752af798a Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:09:40 -0500 Subject: [PATCH] Do recursive function checks early during IR linking (#5777) --- source/slang/slang-emit.cpp | 3 +- source/slang/slang-ir-check-recursion.cpp | 126 ++++++++++++++++++ ...sive-type.h => slang-ir-check-recursion.h} | 4 + .../slang/slang-ir-check-recursive-type.cpp | 65 --------- .../slang/slang-ir-check-unsupported-inst.cpp | 44 ------ source/slang/slang-lower-to-ir.cpp | 2 +- 6 files changed, 133 insertions(+), 111 deletions(-) create mode 100644 source/slang/slang-ir-check-recursion.cpp rename source/slang/{slang-ir-check-recursive-type.h => slang-ir-check-recursion.h} (57%) delete mode 100644 source/slang/slang-ir-check-recursive-type.cpp diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index a9d5c5e508..04ad55c1f7 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -22,7 +22,7 @@ #include "slang-ir-autodiff.h" #include "slang-ir-bind-existentials.h" #include "slang-ir-byte-address-legalize.h" -#include "slang-ir-check-recursive-type.h" +#include "slang-ir-check-recursion.h" #include "slang-ir-check-shader-parameter-type.h" #include "slang-ir-check-unsupported-inst.h" #include "slang-ir-cleanup-void.h" @@ -884,6 +884,7 @@ Result linkAndOptimizeIR( if (targetProgram->getOptionSet().shouldRunNonEssentialValidation()) { checkForRecursiveTypes(irModule, sink); + checkForRecursiveFunctions(codeGenContext->getTargetReq(), irModule, sink); // For some targets, we are more restrictive about what types are allowed // to be used as shader parameters in ConstantBuffer/ParameterBlock. diff --git a/source/slang/slang-ir-check-recursion.cpp b/source/slang/slang-ir-check-recursion.cpp new file mode 100644 index 0000000000..404437a468 --- /dev/null +++ b/source/slang/slang-ir-check-recursion.cpp @@ -0,0 +1,126 @@ +#include "slang-ir-check-recursion.h" + +#include "slang-ir-util.h" + +namespace Slang +{ +bool checkTypeRecursionImpl( + HashSet& checkedTypes, + HashSet& stack, + IRInst* type, + IRInst* field, + DiagnosticSink* sink) +{ + auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool + { + if (!stack.add(elementType)) + { + sink->diagnose(field ? field : type, Diagnostics::recursiveType, type); + return false; + } + if (checkedTypes.add(elementType)) + checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink); + stack.remove(elementType); + return true; + }; + if (auto arrayType = as(type)) + { + return visitElementType(arrayType->getElementType(), field); + } + else if (auto structType = as(type)) + { + for (auto sfield : structType->getFields()) + if (!visitElementType(sfield->getFieldType(), sfield)) + return false; + } + return true; +} + +void checkTypeRecursion(HashSet& checkedTypes, IRInst* type, DiagnosticSink* sink) +{ + HashSet stack; + if (checkedTypes.add(type)) + { + stack.add(type); + checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink); + } +} + +void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink) +{ + HashSet checkedTypes; + for (auto globalInst : module->getGlobalInsts()) + { + switch (globalInst->getOp()) + { + case kIROp_StructType: + { + checkTypeRecursion(checkedTypes, globalInst, sink); + } + break; + default: + break; + } + } +} + +bool checkFunctionRecursionImpl( + HashSet& checkedFuncs, + HashSet& callStack, + IRFunc* func, + DiagnosticSink* sink) +{ + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + auto callInst = as(inst); + if (!callInst) + continue; + auto callee = as(callInst->getCallee()); + if (!callee) + continue; + if (!callStack.add(callee)) + { + sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee); + return false; + } + if (checkedFuncs.add(callee)) + checkFunctionRecursionImpl(checkedFuncs, callStack, callee, sink); + callStack.remove(callee); + } + } + return true; +} + +void checkFunctionRecursion(HashSet& checkedFuncs, IRFunc* func, DiagnosticSink* sink) +{ + HashSet callStack; + if (checkedFuncs.add(func)) + { + callStack.add(func); + checkFunctionRecursionImpl(checkedFuncs, callStack, func, sink); + } +} + +void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink) +{ + HashSet checkedFuncsForRecursionDetection; + for (auto globalInst : module->getGlobalInsts()) + { + switch (globalInst->getOp()) + { + case kIROp_Func: + if (!isCPUTarget(target)) + checkFunctionRecursion( + checkedFuncsForRecursionDetection, + as(globalInst), + sink); + break; + default: + break; + } + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-check-recursive-type.h b/source/slang/slang-ir-check-recursion.h similarity index 57% rename from source/slang/slang-ir-check-recursive-type.h rename to source/slang/slang-ir-check-recursion.h index dd5796c865..1bfcfbee97 100644 --- a/source/slang/slang-ir-check-recursive-type.h +++ b/source/slang/slang-ir-check-recursion.h @@ -4,6 +4,10 @@ namespace Slang { struct IRModule; class DiagnosticSink; +class TargetRequest; void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink); + +void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink); + } // namespace Slang diff --git a/source/slang/slang-ir-check-recursive-type.cpp b/source/slang/slang-ir-check-recursive-type.cpp deleted file mode 100644 index ee45417358..0000000000 --- a/source/slang/slang-ir-check-recursive-type.cpp +++ /dev/null @@ -1,65 +0,0 @@ -#include "slang-ir-check-recursive-type.h" - -#include "slang-ir-util.h" - -namespace Slang -{ -bool checkTypeRecursionImpl( - HashSet& checkedTypes, - HashSet& stack, - IRInst* type, - IRInst* field, - DiagnosticSink* sink) -{ - auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool - { - if (!stack.add(elementType)) - { - sink->diagnose(field ? field : type, Diagnostics::recursiveType, type); - return false; - } - if (checkedTypes.add(elementType)) - checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink); - stack.remove(elementType); - return true; - }; - if (auto arrayType = as(type)) - { - return visitElementType(arrayType->getElementType(), field); - } - else if (auto structType = as(type)) - { - for (auto sfield : structType->getFields()) - if (!visitElementType(sfield->getFieldType(), sfield)) - return false; - } - return true; -} - -void checkTypeRecursion(HashSet& checkedTypes, IRInst* type, DiagnosticSink* sink) -{ - HashSet stack; - if (checkedTypes.add(type)) - { - stack.add(type); - checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink); - } -} - -void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink) -{ - HashSet checkedTypes; - for (auto globalInst : module->getGlobalInsts()) - { - switch (globalInst->getOp()) - { - case kIROp_StructType: - { - checkTypeRecursion(checkedTypes, globalInst, sink); - } - break; - } - } -} - -} // namespace Slang diff --git a/source/slang/slang-ir-check-unsupported-inst.cpp b/source/slang/slang-ir-check-unsupported-inst.cpp index ea9e7cc649..3bf570dc16 100644 --- a/source/slang/slang-ir-check-unsupported-inst.cpp +++ b/source/slang/slang-ir-check-unsupported-inst.cpp @@ -5,46 +5,6 @@ namespace Slang { -bool isCPUTarget(TargetRequest* targetReq); - -bool checkRecursionImpl( - HashSet& checkedFuncs, - HashSet& callStack, - IRFunc* func, - DiagnosticSink* sink) -{ - for (auto block : func->getBlocks()) - { - for (auto inst : block->getChildren()) - { - auto callInst = as(inst); - if (!callInst) - continue; - auto callee = as(callInst->getCallee()); - if (!callee) - continue; - if (!callStack.add(callee)) - { - sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee); - return false; - } - if (checkedFuncs.add(callee)) - checkRecursionImpl(checkedFuncs, callStack, callee, sink); - callStack.remove(callee); - } - } - return true; -} - -void checkRecursion(HashSet& checkedFuncs, IRFunc* func, DiagnosticSink* sink) -{ - HashSet callStack; - if (checkedFuncs.add(func)) - { - callStack.add(func); - checkRecursionImpl(checkedFuncs, callStack, func, sink); - } -} void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* sink) { @@ -65,8 +25,6 @@ void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* s void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink) { - HashSet checkedFuncsForRecursionDetection; - for (auto globalInst : module->getGlobalInsts()) { switch (globalInst->getOp()) @@ -84,8 +42,6 @@ void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSin break; } case kIROp_Func: - if (!isCPUTarget(target)) - checkRecursion(checkedFuncsForRecursionDetection, as(globalInst), sink); checkUnsupportedInst(target, as(globalInst), sink); break; case kIROp_Generic: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 92c442433a..06c5f005b5 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9,7 +9,7 @@ #include "slang-ir-autodiff.h" #include "slang-ir-bit-field-accessors.h" #include "slang-ir-check-differentiability.h" -#include "slang-ir-check-recursive-type.h" +#include "slang-ir-check-recursion.h" #include "slang-ir-clone.h" #include "slang-ir-constexpr.h" #include "slang-ir-dce.h"