Skip to content

Commit

Permalink
Support generic constraints that are dependent on another generic par…
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored May 2, 2024
1 parent 7ef980f commit 1863fe1
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 6 deletions.
29 changes: 27 additions & 2 deletions source/slang/slang-ir-link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,41 @@ void cloneGlobalValueWithCodeCommon(
{
IRBlock* ob = originalValue->getFirstBlock();
IRBlock* cb = clonedValue->getFirstBlock();
struct ParamCloneInfo
{
IRParam* originalParam;
IRParam* clonedParam;
};
ShortList<ParamCloneInfo> paramCloneInfos;
while (ob)
{
SLANG_ASSERT(cb);

builder->setInsertInto(cb);
for (auto oi = ob->getFirstInst(); oi; oi = oi->getNextInst())
{
cloneInst(context, builder, oi);
if (oi->getOp() == kIROp_Param)
{
// Params may have forward references in its type and
// decorations, so we just create a placeholder for it
// in this first pass.
IRParam* clonedParam = builder->emitParam(nullptr);
registerClonedValue(context, clonedParam, oi);
paramCloneInfos.add({ (IRParam*)oi, clonedParam });
}
else
{
cloneInst(context, builder, oi);
}
}
// Clone the type and decorations of parameters after all instructs in the block
// have been cloned.
for (auto param : paramCloneInfos)
{
builder->setInsertInto(param.clonedParam);
param.clonedParam->setFullType((IRType*)cloneValue(context, param.originalParam->getFullType()));
cloneDecorations(context, param.clonedParam, param.originalParam);
}

ob = ob->getNextBlock();
cb = cb->getNextBlock();
}
Expand Down
11 changes: 7 additions & 4 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8823,7 +8823,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IRGenContext* subContext,
GenericTypeConstraintDecl* constraintDecl)
{
auto supType = lowerType(context, constraintDecl->sup.type);
auto supType = lowerType(subContext, constraintDecl->sup.type);
auto value = emitGenericConstraintValue(subContext, constraintDecl, supType);
subContext->setValue(constraintDecl, LoweredValInfo::simple(value));
}
Expand Down Expand Up @@ -8972,9 +8972,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto operand = value->getOperand(i);
markInstsToClone(valuesToClone, parentBlock, operand);
}
if (value->getFullType())
markInstsToClone(valuesToClone, parentBlock, value->getFullType());
for (auto child : value->getDecorationsAndChildren())
markInstsToClone(valuesToClone, parentBlock, child);
}
for (auto child : value->getChildren())
markInstsToClone(valuesToClone, parentBlock, child);
auto parent = parentBlock->getParent();
while (parent && parent != parentBlock)
{
Expand Down Expand Up @@ -9025,7 +9027,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), returnType);
// For Function Types, we always clone all generic parameters regardless of whether
// the generic parameter appears in the function signature or not.
if (returnType->getOp() == kIROp_FuncType)
if (returnType->getOp() == kIROp_FuncType ||
returnType->getOp() == kIROp_Generic)
{
for (auto genericParam : parentGeneric->getParams())
{
Expand Down
64 changes: 64 additions & 0 deletions tests/language-feature/generics/generic-witness-derived.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type

// Test that we can compile a generic function with a generic type constraint that is dependent on an
// outer generic type parameter.

namespace ns{

public interface IBinaryElementWiseFunction<T>
{
public static T call(const in T lhs, const in T rhs);
}
public struct AddOp<T : IArithmetic> : IBinaryElementWiseFunction<T>
{
public static T call(const in T lhs, const in T rhs)
{
return lhs + rhs;
}
}
public struct BinaryElementWiseInputData<T : IArithmetic>
{
T lhs;
T rhs;

// Note: `U` is constrainted by `IBinaryElementWiseFunction<T>`, which is dependent on `T`,
// that is another generic type parameter defined on the outer type.
// This eventually leads to a IRGeneric where one param has a type that is dependent on
// another param.
// In this case, the IR for `test` after generic flattening will be:
// ```
// %g_test = IRGeneric
// {
// IRBlock
// {
// %T = IRParam : Type;
// %T_w = IRParam : IRWitnessTableType<IArithmetic>;
// %U = IRParam : Type;
// %U_w = IRRaram : IRWitnessTableType<%s>; // note that the type here is a forward reference to %s
// %s = specialize(%IBinaryElementWiseFunction, %T) // %s is dependent on %T.
// ...
// }
// }
//
public T test<U : IBinaryElementWiseFunction<T>>(U x)
{
return x.call(lhs ,rhs);
}
}
}


//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;

[shader("compute")]
[numthreads(1,1,1)]
void computeMain(uint3 threadId: SV_DispatchThreadID)
{
ns::BinaryElementWiseInputData<int> cb;
cb.lhs = threadId.x + 1;
cb.rhs = 2;
// CHECK: 3
outputBuffer[0] = cb.test(ns::AddOp<int>());
}

0 comments on commit 1863fe1

Please sign in to comment.