Skip to content

Commit

Permalink
Allow pointers to existential values. (shader-slang#5793)
Browse files Browse the repository at this point in the history
* Fix pointer offset logic and add executable tests.

* Fix.

* Fix test.

* Add existential ptr test.

* Allow pointers to existential values.

* Fix.

* Fix.

---------

Co-authored-by: Ellie Hermaszewska <[email protected]>
  • Loading branch information
csyonghe and expipiplus1 authored Dec 9, 2024
1 parent 051ae8a commit 09a9d67
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 43 deletions.
11 changes: 11 additions & 0 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3105,6 +3105,17 @@ Type* unwrapArrayType(Type* type)
}
}

Type* unwrapModifiedType(Type* type)
{
for (;;)
{
if (auto modType = as<ModifiedType>(type))
type = modType->getBase();
else
return type;
}
}

void discoverExtensionDecls(List<ExtensionDecl*>& decls, Decl* parent)
{
if (auto extDecl = as<ExtensionDecl>(parent))
Expand Down
47 changes: 35 additions & 12 deletions source/slang/slang-check-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2307,7 +2307,10 @@ Expr* SemanticsVisitor::CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type*
Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr)
{
bool needDeref = false;
auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression, needDeref);
auto baseExpr = checkBaseForMemberExpr(
subscriptExpr->baseExpression,
CheckBaseContext::Subscript,
needDeref);

// If the base expression is a type, it means that this is an array declaration,
// then we should disable short-circuit in case there is logical expression in
Expand Down Expand Up @@ -2951,7 +2954,10 @@ Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
auto operatorName = getName("()");

bool needDeref = false;
expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref);
expr->functionExpr = maybeInsertImplicitOpForMemberBase(
expr->functionExpr,
CheckBaseContext::Member,
needDeref);

LookupResult lookupResult = lookUpMember(
m_astBuilder,
Expand Down Expand Up @@ -4060,27 +4066,36 @@ void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr)
}
}

Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr)
Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext)
{
Expr* expr = inExpr;
for (;;)
{
auto baseType = expr->type;
QualType elementType;
if (auto pointerLikeType = as<PointerLikeType>(baseType))
{
auto elementType = QualType(pointerLikeType->getElementType());
elementType = QualType(pointerLikeType->getElementType());
elementType.isLeftValue = baseType.isLeftValue;
elementType.hasReadOnlyOnTarget = baseType.hasReadOnlyOnTarget;
elementType.isWriteOnly = baseType.isWriteOnly;

}
else if (auto ptrType = as<PtrType>(baseType))
{
if (checkBaseContext == CheckBaseContext::Subscript)
return expr;
elementType = QualType(ptrType->getValueType());
elementType.isLeftValue = true;
}
if (elementType.type)
{
auto derefExpr = m_astBuilder->create<DerefExpr>();
derefExpr->base = expr;
derefExpr->type = elementType;

expr = derefExpr;
continue;
}

// Default case: just use the expression as-is
return expr;
}
Expand Down Expand Up @@ -4751,7 +4766,7 @@ Expr* SemanticsExprVisitor::visitStaticMemberExpr(StaticMemberExpr* expr)
expr->baseExpression = CheckTerm(expr->baseExpression);

// Not sure this is needed -> but guess someone could do
expr->baseExpression = MaybeDereference(expr->baseExpression);
expr->baseExpression = maybeDereference(expr->baseExpression, CheckBaseContext::Member);

// If the base of the member lookup has an interface type
// *without* a suitable this-type substitution, then we are
Expand Down Expand Up @@ -4779,9 +4794,12 @@ Expr* SemanticsVisitor::lookupMemberResultFailure(
return expr;
}

Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref)
Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(
Expr* baseExpr,
CheckBaseContext checkBaseContext,
bool& outNeedDeref)
{
auto derefExpr = MaybeDereference(baseExpr);
auto derefExpr = maybeDereference(baseExpr, checkBaseContext);

if (derefExpr != baseExpr)
outNeedDeref = true;
Expand Down Expand Up @@ -4834,11 +4852,15 @@ Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool&
return baseExpr;
}

Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref)
Expr* SemanticsVisitor::checkBaseForMemberExpr(
Expr* inBaseExpr,
CheckBaseContext checkBaseContext,
bool& outNeedDeref)
{
auto baseExpr = inBaseExpr;
baseExpr = CheckTerm(baseExpr);
return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref);

return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
}

Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)
Expand All @@ -4861,7 +4883,8 @@ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* bas
Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr* expr)
{
bool needDeref = false;
expr->baseExpression = checkBaseForMemberExpr(expr->baseExpression, needDeref);
expr->baseExpression =
checkBaseForMemberExpr(expr->baseExpression, CheckBaseContext::Member, needDeref);

if (!needDeref && as<DerefMemberExpr>(expr) && !as<PtrType>(expr->baseExpression->type))
{
Expand Down
19 changes: 15 additions & 4 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2654,8 +2654,6 @@ struct SemanticsVisitor : public SemanticsContext
//
//

Expr* MaybeDereference(Expr* inExpr);

Expr* CheckMatrixSwizzleExpr(
MemberExpr* memberRefExpr,
Type* baseElementType,
Expand Down Expand Up @@ -2696,11 +2694,24 @@ struct SemanticsVisitor : public SemanticsContext

/// Perform checking operations required for the "base" expression of a member-reference like
/// `base.someField`
Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref);
enum class CheckBaseContext
{
Member,
Subscript,
};
Expr* checkBaseForMemberExpr(
Expr* baseExpr,
CheckBaseContext checkBaseContext,
bool& outNeedDeref);

Expr* maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext);

/// Prepare baseExpr for use as the base of a member expr.
/// This include inserting implicit open-existential operations as needed.
Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref);
Expr* maybeInsertImplicitOpForMemberBase(
Expr* baseExpr,
CheckBaseContext checkBaseContext,
bool& outNeedDeref);

Expr* lookupMemberResultFailure(
DeclRefExpr* expr,
Expand Down
5 changes: 3 additions & 2 deletions source/slang/slang-check-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,10 @@ bool isManagedType(Type* type)
{
if (auto declRefValueType = as<DeclRefType>(type))
{
if (as<ClassDecl>(declRefValueType->getDeclRef().getDecl()))
auto decl = declRefValueType->getDeclRef().getDecl();
if (as<ClassDecl>(decl))
return true;
if (as<InterfaceDecl>(declRefValueType->getDeclRef().getDecl()))
if (as<InterfaceDecl>(decl) && decl->findModifier<ComInterfaceAttribute>())
return true;
}
return false;
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-check.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bool isFromCoreModule(Decl* decl);
void registerBuiltinDecls(Session* session, Decl* decl);

Type* unwrapArrayType(Type* type);
Type* unwrapModifiedType(Type* type);

OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericConstraints(
ASTBuilder* builder,
Expand Down
22 changes: 0 additions & 22 deletions source/slang/slang-ir-lower-buffer-element-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,28 +901,6 @@ struct LoweredElementTypeContext
{
builder.setInsertBefore(ptrVal);
auto newArrayPtrVal = fieldAddr->getBase();
// Is base a pointer to an empty struct? If so, don't offset it.
// For example, if the user has written:
// ```
// struct S {int arr[]};
// uniform S* p;
// void test() { p->arr[1]; }
// ```
// Then `S` will become an empty struct after we remove `arr[]`.
// And `p` will be come a `void*`.
// We don't want to offset `p` to `p+1` to get the starting address of
// the array in this case.
IRSizeAndAlignment parentStructSize = {};
getNaturalSizeAndAlignment(
target->getOptionSet(),
tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
&parentStructSize);
if (parentStructSize.size != 0)
{
newArrayPtrVal = builder.emitGetOffsetPtr(
fieldAddr->getBase(),
builder.getIntValue(builder.getIntType(), 1));
}
auto loweredInnerType =
getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);

Expand Down
1 change: 0 additions & 1 deletion tests/bugs/gh-3825.slang
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ float4 fragment(): SV_Target
}

// CHECK: OpDecorate %_ptr_PhysicalStorageBuffer_Descriptors_natural ArrayStride 4
// CHECK: %{{.*}} = OpPtrAccessChain %_ptr_PhysicalStorageBuffer_Descriptors_natural %{{.*}} %int_1
// CHECK: OpBitcast %ulong
// CHECK: OpIAdd %ulong %{{.*}} %ulong_4
// CHECK: OpBitcast %_ptr_PhysicalStorageBuffer
37 changes: 37 additions & 0 deletions tests/spirv/existential-ptr.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly -output-using-type
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal

interface IFoo
{
int getVal();
}

struct Foo : IFoo
{
int val;
int getVal() { return val; }
}

struct Bar : IFoo
{
float val;
int getVal() { return (int)val + 1; }
}

//TEST_INPUT: set pFoo = ubuffer(data=[0 0 2 0 2.0f], stride=4);
//TEST_INPUT: type_conformance Foo:IFoo = 1;
//TEST_INPUT: type_conformance Bar:IFoo = 2;
uniform IFoo* pFoo;

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4);
RWStructuredBuffer<float> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
// CHECK: 3.0
outputBuffer[0] = pFoo->getVal();
}
29 changes: 29 additions & 0 deletions tests/spirv/ptr-member-func.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal

struct Obj
{
int val;

[mutating]
void addOne() { val++; }

int getValPlusOne() { return val + 1; }
}

//TEST_INPUT: set pObj = ubuffer(data=[2 0 0 0], stride=4);
uniform Obj* pObj;

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
uniform RWStructuredBuffer<uint> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
pObj->addOne();
// CHECK: 4
outputBuffer[0] = pObj->getValPlusOne();
}
29 changes: 29 additions & 0 deletions tests/spirv/ptr-unsized-array-3.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal

// Test a pointer to a struct with a trailing unsized array.

struct MeshStorage {
int foo;
uint64_t QuadData[];
};

//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4);
uniform MeshStorage* pStorage;

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
uniform RWStructuredBuffer<uint> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
// CHECK: 5
// CHECK: 6
// CHECK: 1
outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF);
outputBuffer[1] = (int)(pStorage.QuadData[1]>>32);
outputBuffer[2] = pStorage.foo;
}
25 changes: 25 additions & 0 deletions tests/spirv/ptr-unsized-array-4.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal

// Test a pointer to a struct that has only one field and is an unsized array.
struct MeshStorage {
uint64_t QuadData[];
};

//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4);
uniform MeshStorage* pStorage;

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
uniform RWStructuredBuffer<uint> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
// CHECK: 3
// CHECK: 4
outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF);
outputBuffer[1] = (int)(pStorage.QuadData[1]>>32);
}
Loading

0 comments on commit 09a9d67

Please sign in to comment.