Skip to content

Commit

Permalink
Fix spirv codegen for pointer to empty structs. (shader-slang#5355)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Oct 21, 2024
1 parent 20fa42e commit 3e84726
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 26 deletions.
6 changes: 3 additions & 3 deletions source/slang/slang-compiler-tu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ namespace Slang
applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet);
applySettingsToDiagnosticSink(&sink, &sink, m_optionSet);

TargetRequest* targetReq = new TargetRequest(linkage, targetEnum);
RefPtr<TargetRequest> targetReq = new TargetRequest(linkage, targetEnum);

List<RefPtr<ComponentType>> allComponentTypes;
allComponentTypes.add(this); // Add Module as a component type
Expand Down Expand Up @@ -206,8 +206,8 @@ namespace Slang
}
}

ISlangBlob* blob;
outArtifact->loadBlob(ArtifactKeep::Yes, &blob);
ComPtr<ISlangBlob> blob;
outArtifact->loadBlob(ArtifactKeep::Yes, blob.writeRef());

// Add the precompiled blob to the module
builder.setInsertInto(module);
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -3624,6 +3624,10 @@ struct IRBuilder
IRGenericKind* getGenericKind();

IRPtrType* getPtrType(IRType* valueType);

// Form a ptr type to `valueType` using the same opcode and address space as `ptrWithAddrSpace`.
IRPtrTypeBase* getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace);

IROutType* getOutType(IRType* valueType);
IRInOutType* getInOutType(IRType* valueType);
IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);
Expand Down
9 changes: 6 additions & 3 deletions source/slang/slang-ir-legalize-types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,8 @@ static LegalVal legalizeStore(

case LegalVal::Flavor::simple:
{
if (legalVal.flavor == LegalVal::Flavor::none)
return LegalVal();
context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple());
return legalVal;
}
Expand Down Expand Up @@ -2248,7 +2250,7 @@ static LegalVal legalizeLocalVar(
// Easy case: the type is usable as-is, and we
// should just do that.
auto type = maybeSimpleType.getSimple();
type = context->builder->getPtrType(type);
type = context->builder->getPtrTypeWithAddressSpace(type, irLocalVar->getDataType());
if( originalRate )
{
type = context->builder->getRateQualifiedType(
Expand Down Expand Up @@ -3669,7 +3671,7 @@ static LegalVal legalizeGlobalVar(
auto legalValueType = legalizeType(
context,
originalValueType);

auto varPtrType = as<IRPtrTypeBase>(irGlobalVar->getDataType());
switch (legalValueType.flavor)
{
case LegalType::Flavor::simple:
Expand All @@ -3678,7 +3680,8 @@ static LegalVal legalizeGlobalVar(
context->builder->setDataType(
irGlobalVar,
context->builder->getPtrType(
legalValueType.getSimple()));
legalValueType.getSimple(),
varPtrType ? varPtrType->getAddressSpace():AddressSpace::Global));
return LegalVal::simple(irGlobalVar);

default:
Expand Down
31 changes: 26 additions & 5 deletions source/slang/slang-ir-lower-buffer-element-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,26 @@ namespace Slang
if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType()))
{
builder.setInsertBefore(ptrVal);
auto newArrayPtrVal = builder.emitGetOffsetPtr(fieldAddr->getBase(), builder.getIntValue(builder.getIntType(), 1));
auto newArrayPtrVal = fieldAddr->getBase();
// Is base a pointer to an empty struct? If so, don't offset it.
// For example, if the user has written:
// ```
// struct S {int arr[]};
// uniform S* p;
// void test() { p->arr[1]; }
// ```
// Then `S` will become an empty struct after we remove `arr[]`.
// And `p` will be come a `void*`.
// We don't want to offset `p` to `p+1` to get the starting address of the array in this case.
IRSizeAndAlignment parentStructSize = {};
getNaturalSizeAndAlignment(
target->getOptionSet(),
tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
&parentStructSize);
if (parentStructSize.size != 0)
{
newArrayPtrVal = builder.emitGetOffsetPtr(fieldAddr->getBase(), builder.getIntValue(builder.getIntType(), 1));
}
auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);

IRSizeAndAlignment arrayElementSizeAlignment;
Expand All @@ -685,12 +704,14 @@ namespace Slang
&baseSizeAlignment);

// Convert pointer to uint64 and adjust offset.
auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
IRIntegerValue offset = baseSizeAlignment.size;
offset = align(offset, arrayElementSizeAlignment.alignment);
newArrayPtrVal = builder.emitAdd(rawPtr->getFullType(), rawPtr,
builder.getIntValue(builder.getUInt64Type(), offset));

if (offset != 0)
{
auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
newArrayPtrVal = builder.emitAdd(rawPtr->getFullType(), rawPtr,
builder.getIntValue(builder.getUInt64Type(), offset));
}
newArrayPtrVal = builder.emitBitCast(
builder.getPtrType(loweredInnerType.loweredType,
ptrType->getAddressSpace()), newArrayPtrVal);
Expand Down
16 changes: 11 additions & 5 deletions source/slang/slang-ir-spirv-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "slang-ir-loop-unroll.h"
#include "slang-ir-lower-buffer-element-type.h"
#include "slang-ir-specialize-address-space.h"
#include "slang-legalize-types.h"

namespace Slang
{
Expand All @@ -37,6 +38,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase

IRModule* m_module;

DiagnosticSink* m_sink;

struct LoweredStructuredBufferTypeInfo
{
IRType* structType;
Expand Down Expand Up @@ -173,8 +176,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}

SPIRVLegalizationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module)
: m_sharedContext(sharedContext), m_module(module)
SPIRVLegalizationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink)
: m_sharedContext(sharedContext), m_module(module), m_sink(sink)
{
}

Expand Down Expand Up @@ -2108,6 +2111,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// safely lower the pointer load stores early together with other buffer types.
lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true);

// The above step may produce empty struct types, so we need to lower them out of existence.
legalizeEmptyTypes(m_sharedContext->m_targetProgram, m_module, m_sink);

// Specalize address space for all pointers.
SpirvAddressSpaceAssigner addressSpaceAssigner;
specializeAddressSpace(m_module, &addressSpaceAssigner);
Expand Down Expand Up @@ -2184,9 +2190,9 @@ SpvSnippet* SPIRVEmitSharedContext::getParsedSpvSnippet(IRTargetIntrinsicDecorat
return snippet;
}

void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module)
void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink)
{
SPIRVLegalizationContext context(sharedContext, module);
SPIRVLegalizationContext context(sharedContext, module, sink);
context.processModule();
}

Expand Down Expand Up @@ -2326,7 +2332,7 @@ void legalizeIRForSPIRV(
CodeGenContext* codeGenContext)
{
SLANG_UNUSED(entryPoints);
legalizeSPIRV(context, module);
legalizeSPIRV(context, module, codeGenContext->getSink());
simplifyIRForSpirvLegalization(context->m_targetProgram, codeGenContext->getSink(), module);
buildEntryPointReferenceGraph(context->m_referencingEntryPoints, module);
insertFragmentShaderInterlock(context, module);
Expand Down
7 changes: 7 additions & 0 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2881,6 +2881,13 @@ namespace Slang
operands);
}

IRPtrTypeBase* IRBuilder::getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace)
{
if (ptrWithAddrSpace->hasAddressSpace())
return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType, ptrWithAddrSpace->getAddressSpace());
return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType);
}

IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace)
{
return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(addressSpace)));
Expand Down
42 changes: 32 additions & 10 deletions source/slang/slang-legalize-types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,24 +896,46 @@ static LegalType createLegalUniformBufferType(
// Create a pointer type with a given legalized value type.
static LegalType createLegalPtrType(
TypeLegalizationContext* context,
IROp op,
IRInst* originalPtrType,
LegalType legalValueType)
{
switch (legalValueType.flavor)
{
case LegalType::Flavor::none:
if (auto ptrType = as<IRPtrType>(originalPtrType))
{
switch (ptrType->getAddressSpace())
{
case AddressSpace::UserPointer:
case AddressSpace::Global:
// If this is a physical pointer, we need to create an untyped pointer if
// the element type is nothing.
return LegalType::simple(
context->getBuilder()->getPtrTypeWithAddressSpace(
context->getBuilder()->getVoidType(),
ptrType));
}
}
return LegalType();

case LegalType::Flavor::simple:
{
// Easy case: we just have a simple element type,
// so we want to create a uniform buffer that wraps it.
// Easy case: we just have a simple element type.
if (auto ptrTypeBase = as<IRPtrTypeBase>(originalPtrType))
{
if (ptrTypeBase->hasAddressSpace())
{
return LegalType::simple(
context->getBuilder()->getPtrTypeWithAddressSpace(
legalValueType.getSimple(),
ptrTypeBase));
}
}
return LegalType::simple(createBuiltinGenericType(
context,
op,
originalPtrType->getOp(),
legalValueType.getSimple()));
}
break;

case LegalType::Flavor::implicitDeref:
{
Expand All @@ -936,7 +958,7 @@ static LegalType createLegalPtrType(
// will matter.
return LegalType::implicitDeref(createLegalPtrType(
context,
op,
originalPtrType,
legalValueType.getImplicitDeref()->valueType));
}
break;
Expand All @@ -948,11 +970,11 @@ static LegalType createLegalPtrType(

auto ordinaryType = createLegalPtrType(
context,
op,
originalPtrType,
pairType->ordinaryType);
auto specialType = createLegalPtrType(
context,
op,
originalPtrType,
pairType->specialType);

return LegalType::pair(ordinaryType, specialType, pairType->pairInfo);
Expand All @@ -974,7 +996,7 @@ static LegalType createLegalPtrType(
newElement.key = ee.key;
newElement.type = createLegalPtrType(
context,
op,
originalPtrType,
ee.type);

ptrPseudoTupleType->elements.add(newElement);
Expand Down Expand Up @@ -1310,7 +1332,7 @@ LegalType legalizeTypeImpl(
if (legalValueType.flavor == LegalType::Flavor::simple &&
legalValueType.getSimple() == ptrType->getValueType())
return LegalType::simple(ptrType);
return createLegalPtrType(context, ptrType->getOp(), legalValueType);
return createLegalPtrType(context, ptrType, legalValueType);
}
else if(auto structType = as<IRStructType>(type))
{
Expand Down
15 changes: 15 additions & 0 deletions tests/spirv/ptr-empty-struct.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// CHECK: OpPtrAccessChain

struct EmptyStruct {
};

[vk::push_constant] EmptyStruct* pc;

RWStructuredBuffer<int> outputBuffer;

[numthreads(64)]
void ComputeMain(uint tid: SV_DispatchThreadID) {
outputBuffer[tid] = ((int*)(pc))[0];
}
25 changes: 25 additions & 0 deletions tests/spirv/ptr-unsized-array-2.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// CHECK-DAG: %[[cbuffer__t:[A-Za-z0-9_]+]] = OpTypeStruct %_ptr_PhysicalStorageBuffer_uint
// CHECK-DAG: %light_buffer = OpVariable %_ptr_PushConstant_[[cbuffer__t]] PushConstant

// CHECK: OpAccessChain %_ptr_PushConstant
// CHECK-NEXT: OpLoad
// CHECK-NEXT: OpBitcast %_ptr_PhysicalStorageBuffer

struct LightBuffer {
uint8_t lights[];
}

[vk::push_constant]
LightBuffer* light_buffer;

[shader("vertex")]
float4 vertMain() : SV_Position {
return float4(light_buffer.lights[0]);
}

[shader("fragment")]
float4 fragMain() : COLOR0 {
return float4(1.0);
}

0 comments on commit 3e84726

Please sign in to comment.