Skip to content

Commit

Permalink
Synthesize conformance for generic requirements. (shader-slang#5111)
Browse files Browse the repository at this point in the history
* Synthesize conformance for generic requirements.

* Fix.

* Fix build error.

* address code review.
  • Loading branch information
csyonghe authored Sep 19, 2024
1 parent dd3d80e commit 26ca9c5
Show file tree
Hide file tree
Showing 7 changed files with 450 additions and 114 deletions.
403 changes: 293 additions & 110 deletions source/slang/slang-check-decl.cpp

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1743,6 +1743,12 @@ namespace Slang
CallableDecl* synthesized,
List<Expr*>& synArgs);

CallableDecl* synthesizeMethodSignatureForRequirementWitnessInner(
ConformanceCheckingContext* context,
DeclRef<CallableDecl> requiredMemberDeclRef,
List<Expr*>& synArgs,
ThisExpr*& synThis);

CallableDecl* synthesizeMethodSignatureForRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<CallableDecl> requiredMemberDeclRef,
Expand Down Expand Up @@ -1803,6 +1809,7 @@ namespace Slang

bool trySynthesizeSubscriptRequirementWitness(
ConformanceCheckingContext* context,
const LookupResult& lookupResult,
DeclRef<SubscriptDecl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable);

Expand Down
3 changes: 1 addition & 2 deletions source/slang/slang-check-overload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2674,8 +2674,7 @@ namespace Slang
else if (auto overloadedExpr = as<OverloadedExpr>(baseExpr))
{
// We are referring to a bunch of declarations, each of which might be generic
LookupResult result;
for (auto item : overloadedExpr->lookupResult2.items)
for (auto item : overloadedExpr->lookupResult2)
{
AddGenericOverloadCandidate(item, context);
}
Expand Down
5 changes: 3 additions & 2 deletions source/slang/slang-language-server-completion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ static const char* kDeclKeywords[] = {
"protected", "typedef", "typealias", "uniform", "export", "groupshared",
"extension", "associatedtype", "namespace", "This", "using",
"__generic", "__exported", "import", "enum", "cbuffer", "tbuffer", "func",
"functype"};
"functype", "typename", "each", "expand" };
static const char* kStmtKeywords[] = {
"if", "else", "switch", "case", "default", "return",
"try", "throw", "throws", "catch", "while", "for",
Expand All @@ -35,7 +35,8 @@ static const char* kStmtKeywords[] = {
"__generic", "__exported", "import", "enum", "break", "continue",
"discard", "defer", "cbuffer", "tbuffer", "func", "is",
"as", "nullptr", "none", "true", "false", "functype",
"sizeof", "alignof", "__target_switch", "__intrinsic_asm"};
"sizeof", "alignof", "__target_switch", "__intrinsic_asm",
"each", "expand" };

static const char* hlslSemanticNames[] = {
"register",
Expand Down
21 changes: 21 additions & 0 deletions source/slang/slang-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,27 @@ namespace Slang
expr->baseExpression->accept(this, nullptr);
expr->scope = scope;
}
void visitAppExprBase(AppExprBase* expr)
{
expr->functionExpr->accept(this, nullptr);
for (auto arg : expr->arguments)
arg->accept(this, nullptr);
}
void visitIsTypeExpr(IsTypeExpr* expr)
{
if (expr->typeExpr.exp)
expr->typeExpr.exp->accept(this, nullptr);
}
void visitAsTypeExpr(AsTypeExpr* expr)
{
if (expr->typeExpr)
expr->typeExpr->accept(this, nullptr);
}
void visiSizeOfLikeExpr(SizeOfLikeExpr* expr)
{
if (expr->value)
expr->value->accept(this, nullptr);
}
void visitExpr(Expr* /*expr*/)
{}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Test that we allow type conformances whose base interface is generic.

//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-dx11 -compute -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -output-using-type

interface IStack<let D : int>
{
IStack<D - N> popN<let N : int>();

int get();
}
struct StackImpl<let D : int> : IStack<D>
{
// member 'popN' does not match interface requirement.
StackImpl<D - N> popN<int N>() { return StackImpl<D - N>(); }

int get() { return D; }
}

int helper<int n, T : IStack<n>>(T stack)
{
return stack.popN<2>().get();
}

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4);
RWStructuredBuffer<int> outputBuffer;

[numthreads(1, 1, 1)]
void computeMain()
{
StackImpl<5> obj = StackImpl<5>();

// CHECK: 3
outputBuffer[0] = helper(obj);
}
90 changes: 90 additions & 0 deletions tests/language-feature/interfaces/generic-requirement-synth.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Test that we can synthesize requirements for generic methods.

//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-dx11 -compute -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -output-using-type

interface IBase
{
static float get();
}
interface IBar : IBase
{
float derivedMethod();
}

struct Bar : IBar
{
static float get() { return 1.0f; }
float derivedMethod() { return 2.0f; }
}

interface ITestInterface<Real : IFloat>
{
Real sample<T : IBar>(T t);

__init<T : IBar>(T t);

__generic<T : IBar>
__subscript(T t)->Real { get; }
}

struct TestInterfaceImpl<Real : IFloat> : ITestInterface<Real>
{
// The signature of this sample method is different from the one in the
// interface. However, we should be able to form a call into this method
// from the synthesized implementation matching the interface definition,
// so the conformance should hold.
Real sample<T : IBase>(T t)
{
return x + Real(T.get());
}

// Test the same thing for constructors.
__init<T : IBase>(T t)
{
x = Real(T.get());
}

// Test the same thing for subscript operators.
__generic<T : IBase>
__subscript(T t)->Real { get { return x + Real(T.get()); } }
Real x;
}

float test(ITestInterface<float> obj)
{
Bar b = {};
return obj.sample<Bar>(b);
}

float test1(ITestInterface<float> obj)
{
Bar b = {};
return obj[b];
}

float test2<T:ITestInterface<float>>()
{
Bar b = {};
T obj = T(b);
return obj[b];
}

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4);
RWStructuredBuffer<int> outputBuffer;

[numthreads(1, 1, 1)]
void computeMain()
{
TestInterfaceImpl<float> obj;
obj.x = 1.0f;

// CHECK: 2
outputBuffer[0] = int(test(obj));

// CHECK: 2
outputBuffer[1] = int(test1(obj));

// CHECK: 2
outputBuffer[3] = int(test2<TestInterfaceImpl<float>>());
}

0 comments on commit 26ca9c5

Please sign in to comment.