diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index ee9b55334b..863ffaef23 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -744,7 +744,7 @@ class MaxItersAttribute : public Attribute { SLANG_AST_CLASS(MaxItersAttribute) - int32_t value = 0; + IntVal* value = 0; }; // An inferred max iteration count on a loop. diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 0794279a74..a1e0f78764 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -720,11 +720,7 @@ Modifier* SemanticsVisitor::validateAttribute( } else { - auto cint = checkConstantIntVal(attr->args[0]); - if (cint) - { - maxItersAttrs->value = (int32_t)cint->getValue(); - } + maxItersAttrs->value = checkLinkTimeConstantIntVal(attr->args[0]); } } else if (const auto userDefAttr = as(attr)) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4c72727554..829e725757 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4637,9 +4637,14 @@ struct IRBuilder getIntValue(getIntType(), IRIntegerValue(mode))); } - void addLoopMaxItersDecoration(IRInst* value, IntegerLiteralValue iters) + void addLoopMaxItersDecoration(IRInst* value, IRIntegerValue iters) { - addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(getIntType(), iters)); + addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(iters)); + } + + void addLoopMaxItersDecoration(IRInst* value, IRInst* iters) + { + addDecoration(value, kIROp_LoopMaxItersDecoration, iters); } void addLoopForceUnrollDecoration(IRInst* value, IntegerLiteralValue iters) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f6c662a98d..29fbcc3c95 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3890,6 +3890,8 @@ enum class TypeCastStyle }; static TypeCastStyle _getTypeStyleId(IRType* type) { + type = (IRType*)unwrapAttributedType(type); + if (auto vectorType = as(type)) { return _getTypeStyleId(vectorType->getElementType()); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 75cf421af0..ce6f8cb427 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -5938,7 +5938,8 @@ struct StmtLoweringVisitor : StmtVisitor if (auto maxItersAttr = stmt->findModifier()) { - getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value); + auto iters = lowerVal(context, maxItersAttr->value); + getBuilder()->addLoopMaxItersDecoration(inst, getSimpleVal(context, iters)); } else if (auto inferredMaxItersAttr = stmt->findModifier()) { @@ -6028,12 +6029,15 @@ struct StmtLoweringVisitor : StmtVisitor { if (auto maxIters = stmt->findModifier()) { - if (inferredMaxIters->value < maxIters->value) + if (auto constIntVal = as(maxIters->value)) { - context->getSink()->diagnose( - maxIters, - Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters, - inferredMaxIters->value); + if (inferredMaxIters->value < constIntVal->getValue()) + { + context->getSink()->diagnose( + maxIters, + Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters, + inferredMaxIters->value); + } } } } diff --git a/tests/bugs/gh-5781.slang b/tests/bugs/gh-5781.slang new file mode 100644 index 0000000000..33456f5001 --- /dev/null +++ b/tests/bugs/gh-5781.slang @@ -0,0 +1,57 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +// CHECK: OpEntryPoint + +module test; + +public enum class MaterialID : uint { invalid = 0xffffffff }; + +public struct Material : IDifferentiable +{ + float x; +} + +public struct Hit +{ + MaterialID material; +} + +public struct Scene +{ + StructuredBuffer materials; + RWStructuredBuffer grads; + + [Differentiable] + Material load(MaterialID id) { return materials[uint(id)]; } + + void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; } + + [Differentiable, BackwardDerivative(_get_material_bwd)] + public Material get_material(MaterialID id) { return load(id); } + + public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); } + + [Differentiable] + public Material get_material(Hit hit) { return get_material(hit.material); } +} + +[Differentiable] +float trace(const Scene scene, Hit hit) +{ + Material m = scene.get_material(hit); + return m.x; +} + + +[shader("compute")] +void main( + uniform Scene scene, + uniform StructuredBuffer input, + uniform RWStructuredBuffer output, + uniform RWStructuredBuffer grads +) +{ + Hit hit; + hit.material = MaterialID(input[0]); + output[0] = trace(scene, hit); + bwd_diff(trace)(scene, hit, grads[0]); +} \ No newline at end of file diff --git a/tests/language-feature/constants/max-iters-link-time-const.slang b/tests/language-feature/constants/max-iters-link-time-const.slang new file mode 100644 index 0000000000..cf1ccbbd1b --- /dev/null +++ b/tests/language-feature/constants/max-iters-link-time-const.slang @@ -0,0 +1,15 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +// CHECK: OpEntryPoint + +extern static const int num = 10; +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + [MaxIters(num)] + for (int i = 0; i < num; i++) + { + outputBuffer[i] = i; + } +}