Skip to content

Commit

Permalink
Fix crash during emitCast of attributed type, allow MaxIters to take …
Browse files Browse the repository at this point in the history
…linktime const. (shader-slang#5791)

* Fix crash during emitCast of attributed type.

* Allow [MaxIters] to take link time constants.

---------

Co-authored-by: Ellie Hermaszewska <[email protected]>
  • Loading branch information
csyonghe and expipiplus1 authored Dec 9, 2024
1 parent 71e90a7 commit 051ae8a
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 14 deletions.
2 changes: 1 addition & 1 deletion source/slang/slang-ast-modifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ class MaxItersAttribute : public Attribute
{
SLANG_AST_CLASS(MaxItersAttribute)

int32_t value = 0;
IntVal* value = 0;
};

// An inferred max iteration count on a loop.
Expand Down
6 changes: 1 addition & 5 deletions source/slang/slang-check-modifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,11 +720,7 @@ Modifier* SemanticsVisitor::validateAttribute(
}
else
{
auto cint = checkConstantIntVal(attr->args[0]);
if (cint)
{
maxItersAttrs->value = (int32_t)cint->getValue();
}
maxItersAttrs->value = checkLinkTimeConstantIntVal(attr->args[0]);
}
}
else if (const auto userDefAttr = as<UserDefinedAttribute>(attr))
Expand Down
9 changes: 7 additions & 2 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -4637,9 +4637,14 @@ struct IRBuilder
getIntValue(getIntType(), IRIntegerValue(mode)));
}

void addLoopMaxItersDecoration(IRInst* value, IntegerLiteralValue iters)
void addLoopMaxItersDecoration(IRInst* value, IRIntegerValue iters)
{
addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(getIntType(), iters));
addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(iters));
}

void addLoopMaxItersDecoration(IRInst* value, IRInst* iters)
{
addDecoration(value, kIROp_LoopMaxItersDecoration, iters);
}

void addLoopForceUnrollDecoration(IRInst* value, IntegerLiteralValue iters)
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3890,6 +3890,8 @@ enum class TypeCastStyle
};
static TypeCastStyle _getTypeStyleId(IRType* type)
{
type = (IRType*)unwrapAttributedType(type);

if (auto vectorType = as<IRVectorType>(type))
{
return _getTypeStyleId(vectorType->getElementType());
Expand Down
16 changes: 10 additions & 6 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5938,7 +5938,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>

if (auto maxItersAttr = stmt->findModifier<MaxItersAttribute>())
{
getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value);
auto iters = lowerVal(context, maxItersAttr->value);
getBuilder()->addLoopMaxItersDecoration(inst, getSimpleVal(context, iters));
}
else if (auto inferredMaxItersAttr = stmt->findModifier<InferredMaxItersAttribute>())
{
Expand Down Expand Up @@ -6028,12 +6029,15 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
{
if (auto maxIters = stmt->findModifier<MaxItersAttribute>())
{
if (inferredMaxIters->value < maxIters->value)
if (auto constIntVal = as<ConstantIntVal>(maxIters->value))
{
context->getSink()->diagnose(
maxIters,
Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters,
inferredMaxIters->value);
if (inferredMaxIters->value < constIntVal->getValue())
{
context->getSink()->diagnose(
maxIters,
Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters,
inferredMaxIters->value);
}
}
}
}
Expand Down
57 changes: 57 additions & 0 deletions tests/bugs/gh-5781.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv
// CHECK: OpEntryPoint

module test;

public enum class MaterialID : uint { invalid = 0xffffffff };

public struct Material : IDifferentiable
{
float x;
}

public struct Hit
{
MaterialID material;
}

public struct Scene
{
StructuredBuffer<Material> materials;
RWStructuredBuffer<Material> grads;

[Differentiable]
Material load(MaterialID id) { return materials[uint(id)]; }

void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; }

[Differentiable, BackwardDerivative(_get_material_bwd)]
public Material get_material(MaterialID id) { return load(id); }

public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); }

[Differentiable]
public Material get_material(Hit hit) { return get_material(hit.material); }
}

[Differentiable]
float trace(const Scene scene, Hit hit)
{
Material m = scene.get_material(hit);
return m.x;
}


[shader("compute")]
void main(
uniform Scene scene,
uniform StructuredBuffer<uint> input,
uniform RWStructuredBuffer<float> output,
uniform RWStructuredBuffer<float> grads
)
{
Hit hit;
hit.material = MaterialID(input[0]);
output[0] = trace(scene, hit);
bwd_diff(trace)(scene, hit, grads[0]);
}
15 changes: 15 additions & 0 deletions tests/language-feature/constants/max-iters-link-time-const.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv
// CHECK: OpEntryPoint

extern static const int num = 10;
RWStructuredBuffer<float> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
[MaxIters(num)]
for (int i = 0; i < num; i++)
{
outputBuffer[i] = i;
}
}

0 comments on commit 051ae8a

Please sign in to comment.