Skip to content

Commit

Permalink
Fix unzipping logic for inout non-diff parameters and adjust tests (s…
Browse files Browse the repository at this point in the history
…hader-slang#4090)

* Fix unzipping logic for inout non-diff parameters and adjust tests

+ Removed `-g0` from `struct-this-parameter.slang` test. Works correctly with the new unzipping logic.
+ Removed `-g0` from `was/warped-sampling-1d.slang` test. Works correctly with DX12 & CS_5_1. CS_5_0 appears to run into an FXC compiler bug with detecting infinite loops where there don't appear to be any.

* Update slang-ir-autodiff-unzip.h

* Update warped-sampling-1d.slang
  • Loading branch information
saipraveenb25 authored May 2, 2024
1 parent 6b30957 commit 7ef980f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
28 changes: 25 additions & 3 deletions source/slang/slang-ir-autodiff-unzip.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,32 @@ struct DiffUnzipPass
}
else
{
// For non differentiable arguments, we can simply pass the argument as is
// if this isn't a `out` parameter, in which case it is removed from propagate call.
if (!as<IROutType>(arg->getDataType()))
if (auto inOutType = as<IRInOutType>(resolvedPrimalFuncType->getParamType(ii)))
{
// For 'inout' parameter we need to create a temp var to hold the value
// before the primal call. This logic is similar to the 'inout' case for differentiable params
// only we don't need to deal with pair types.
//
auto tempPrimalVar = primalBuilder->emitVar(as<IRPtrTypeBase>(arg->getDataType())->getValueType());

auto storeUse = findUniqueStoredVal(cast<IRVar>(arg));
auto storeInst = cast<IRStore>(storeUse->getUser());
auto storedVal = storeInst->getVal();

primalBuilder->emitStore(tempPrimalVar, storedVal);

diffArgs.add(tempPrimalVar);
}
else
{
// For pure 'in' type. Simply re-use the original argument inst.
//
// For 'out' type parameters, it doesn't really matter what we pass in here, since
// the tranposition logic will discard the argument anyway (we'll pass in the old arg,
// just to keep the number of arguments consistent)
//
diffArgs.add(arg);
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions tests/autodiff/struct-this-parameter.slang
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -g0
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
Expand Down
2 changes: 1 addition & 1 deletion tests/autodiff/was/warped-sampling-1d.slang
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -g0
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -profile cs_5_1 -dx12

//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out,name=endpointDifferentialBuffer
RWStructuredBuffer<float> endpointDifferentialBuffer;
Expand Down

0 comments on commit 7ef980f

Please sign in to comment.