forked from shader-slang/slang
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix crash during emitCast of attributed type, allow MaxIters to take …
…linktime const. (shader-slang#5791) * Fix crash during emitCast of attributed type. * Allow [MaxIters] to take link time constants. --------- Co-authored-by: Ellie Hermaszewska <[email protected]>
- Loading branch information
1 parent
71e90a7
commit 051ae8a
Showing
7 changed files
with
93 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
//TEST:SIMPLE(filecheck=CHECK): -target spirv | ||
// CHECK: OpEntryPoint | ||
|
||
module test; | ||
|
||
public enum class MaterialID : uint { invalid = 0xffffffff }; | ||
|
||
public struct Material : IDifferentiable | ||
{ | ||
float x; | ||
} | ||
|
||
public struct Hit | ||
{ | ||
MaterialID material; | ||
} | ||
|
||
public struct Scene | ||
{ | ||
StructuredBuffer<Material> materials; | ||
RWStructuredBuffer<Material> grads; | ||
|
||
[Differentiable] | ||
Material load(MaterialID id) { return materials[uint(id)]; } | ||
|
||
void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; } | ||
|
||
[Differentiable, BackwardDerivative(_get_material_bwd)] | ||
public Material get_material(MaterialID id) { return load(id); } | ||
|
||
public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); } | ||
|
||
[Differentiable] | ||
public Material get_material(Hit hit) { return get_material(hit.material); } | ||
} | ||
|
||
[Differentiable] | ||
float trace(const Scene scene, Hit hit) | ||
{ | ||
Material m = scene.get_material(hit); | ||
return m.x; | ||
} | ||
|
||
|
||
[shader("compute")] | ||
void main( | ||
uniform Scene scene, | ||
uniform StructuredBuffer<uint> input, | ||
uniform RWStructuredBuffer<float> output, | ||
uniform RWStructuredBuffer<float> grads | ||
) | ||
{ | ||
Hit hit; | ||
hit.material = MaterialID(input[0]); | ||
output[0] = trace(scene, hit); | ||
bwd_diff(trace)(scene, hit, grads[0]); | ||
} |
15 changes: 15 additions & 0 deletions
15
tests/language-feature/constants/max-iters-link-time-const.slang
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
//TEST:SIMPLE(filecheck=CHECK): -target spirv | ||
// CHECK: OpEntryPoint | ||
|
||
extern static const int num = 10; | ||
RWStructuredBuffer<float> outputBuffer; | ||
|
||
[numthreads(1,1,1)] | ||
void computeMain() | ||
{ | ||
[MaxIters(num)] | ||
for (int i = 0; i < num; i++) | ||
{ | ||
outputBuffer[i] = i; | ||
} | ||
} |