Skip to content

Commit

Permalink
Support groupshared variables for Metal. (shader-slang#4116)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored May 7, 2024
1 parent 618428a commit 1b3a428
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 17 deletions.
42 changes: 38 additions & 4 deletions source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_ParameterBlockType:
case kIROp_ConstantBufferType:
{
emitType((IRType*)type->getOperand(0));
emitSimpleTypeImpl((IRType*)type->getOperand(0));
m_writer->emit(" constant*");
return;
}
Expand Down Expand Up @@ -607,11 +607,17 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
}
}

void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace)
void MetalSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator)
{
if (as<IRGroupSharedRate>(rate))
switch (type->getOp())
{
m_writer->emit("threadgroup ");
case kIROp_ArrayType:
emitSimpleType(type);
emitDeclarator(declarator);
break;
default:
Super::_emitType(type, declarator);
break;
}
}

Expand Down Expand Up @@ -796,6 +802,34 @@ void MetalSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTy
// We emit packoffset as a semantic in `emitSemantic`, so nothing to do here.
}

/// Emits the address-space qualifier keyword (followed by a space) that
/// corresponds to the given rate / address space.
///
/// @param rate          Rate of the value being declared; a group-shared rate
///                      takes precedence over `addressSpace`.
/// @param addressSpace  Integer-encoded `AddressSpace` value, consulted only
///                      when `rate` is not a group-shared rate.
///
/// Nothing is emitted for address spaces that are not listed below.
void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace)
{
    // The rate wins: a group-shared rate always yields `threadgroup`.
    if (as<IRGroupSharedRate>(rate))
    {
        m_writer->emit("threadgroup ");
        return;
    }

    // Otherwise pick the keyword from the address space; a null result
    // means "no qualifier".
    const char* qualifier = nullptr;
    switch ((AddressSpace)addressSpace)
    {
    case AddressSpace::GroupShared:
        qualifier = "threadgroup ";
        break;
    case AddressSpace::Uniform:
        qualifier = "constant ";
        break;
    case AddressSpace::Global:
        qualifier = "device ";
        break;
    case AddressSpace::ThreadLocal:
        qualifier = "thread ";
        break;
    default:
        break;
    }

    if (qualifier)
        m_writer->emit(qualifier);
}


void MetalSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst)
{
SLANG_UNUSED(varInst);
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-emit-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class MetalSourceEmitter : public CLikeSourceEmitter

void emitFuncParamLayoutImpl(IRInst* param);

virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE;

void _emitHLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type);

void _emitHLSLTextureType(IRTextureTypeBase* texType);
Expand Down
90 changes: 77 additions & 13 deletions source/slang/slang-ir-explicit-global-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ struct IntroduceExplicitGlobalContextPass
List<IRGlobalVar*> m_globalVars;
List<IRFunc*> m_entryPoints;

// Tells `createContextStructField` what kind of global object a context
// struct field is being created for.
enum class GlobalObjectKind
{
// GlobalParam: a global shader parameter; the field uses the given type as-is.
// GlobalVar:   a global variable; the given type is a pointer type and the
//              field type is derived from it.
GlobalParam, GlobalVar
};

void processModule()
{
IRBuilder builder(m_module);
Expand Down Expand Up @@ -181,14 +186,14 @@ struct IntroduceExplicitGlobalContextPass
// parameters, we create a field that exactly matches its type.
//

createContextStructField(globalParam, globalParam->getFullType());
createContextStructField(globalParam, GlobalObjectKind::GlobalParam, globalParam->getFullType());
}
for( auto globalVar : m_globalVars )
{
// A `IRGlobalVar` represents a pointer to where the variable is stored,
// so we need to create a field of the pointed-to type to represent it.
//
createContextStructField(globalVar, globalVar->getDataType()->getValueType());
createContextStructField(globalVar, GlobalObjectKind::GlobalVar, getGlobalVarPtrType(globalVar));
}

// Once all the fields have been created, we can process the entry points.
Expand Down Expand Up @@ -229,22 +234,46 @@ struct IntroduceExplicitGlobalContextPass
// variable parameter, and to record the context pointer
// value to use for a function.
//
Dictionary<IRInst*, IRStructKey*> m_mapInstToContextFieldKey;
// Describes the context-struct field that stands in for one global
// parameter or global variable.
struct ContextFieldInfo
{
// Struct key identifying the field inside the context struct.
IRStructKey* key = nullptr;

// Is this field a pointer to the actual value?
// For groupshared variables, this will be true.
// When true, users of the field must load the stored pointer before
// using it as the address of the variable.
bool needDereference = false;
};
Dictionary<IRInst*, ContextFieldInfo> m_mapInstToContextFieldInfo;
Dictionary<IRFunc*, IRInst*> m_mapFuncToContextPtr;

void createContextStructField(IRInst* originalInst, IRType* type)
void createContextStructField(IRInst* originalInst, GlobalObjectKind kind, IRType* type)
{
// Creating a field in the context struct to represent
// `originalInst` is straightforward.

IRBuilder builder(m_module);
builder.setInsertBefore(m_contextStructType);

IRType* fieldDataType = type;
bool needDereference = false;
if (kind == GlobalObjectKind::GlobalVar)
{
auto ptrType = as<IRPtrTypeBase>(type);
if (ptrType->getAddressSpace() == (IRIntegerValue)AddressSpace::GroupShared)
{
fieldDataType = ptrType;
needDereference = true;
}
else
{
fieldDataType = as<IRPtrTypeBase>(type)->getValueType();
}
}

// We create a "key" for the new field, and then a field
// of the appropriate type.
//
auto key = builder.createStructKey();
builder.createStructField(m_contextStructType, key, type);
builder.createStructField(m_contextStructType, key, fieldDataType);

// Clone all original decorations to the new struct key.
IRCloneEnv cloneEnv;
Expand All @@ -254,7 +283,7 @@ struct IntroduceExplicitGlobalContextPass
// for the instruction, so that we can use the key
// to access the field later.
//
m_mapInstToContextFieldKey.add(originalInst, key);
m_mapInstToContextFieldInfo.add(originalInst, ContextFieldInfo{ key, needDereference });
}

void createContextForEntryPoint(IRFunc* entryPointFunc)
Expand Down Expand Up @@ -321,14 +350,14 @@ struct IntroduceExplicitGlobalContextPass
//
for (auto entryPointParam : entryPointParams)
{
auto fieldKey = m_mapInstToContextFieldKey[entryPointParam.globalParam];
auto fieldInfo = m_mapInstToContextFieldInfo[entryPointParam.globalParam];
auto fieldType = entryPointParam.globalParam->getFullType();
auto fieldPtrType = builder.getPtrType(fieldType);

// We compute the address of the field and store the
// value of the parameter into it.
//
auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldKey);
auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldInfo.key);
builder.emitStore(fieldPtr, entryPointParam.entryPointParam);
}

Expand All @@ -341,6 +370,27 @@ struct IntroduceExplicitGlobalContextPass
// run the pass in `slang-ir-explicit-global-init` first,
// in order to move all initialization of globals into the
// entry point functions.
//
// To support groupshared variables on Metal, we need to allocate the
// memory by defining a local variable in the entry point, and pass
// the address of that variable to the context.
//
for (auto globalVar : m_globalVars)
{
auto fieldInfo = m_mapInstToContextFieldInfo[globalVar];
if (fieldInfo.needDereference)
{
auto var = builder.emitVar(globalVar->getDataType()->getValueType(), (IRIntegerValue)AddressSpace::GroupShared);
if (auto nameDecor = globalVar->findDecoration<IRNameHintDecoration>())
{
builder.addNameHintDecoration(var, nameDecor->getName());
}
auto ptrPtrType = builder.getPtrType(getGlobalVarPtrType(globalVar), AddressSpace::ThreadLocal);
auto fieldPtr = builder.emitFieldAddress(ptrPtrType, contextVarPtr, fieldInfo.key);
builder.emitStore(fieldPtr, var);
}
}

}

void replaceUsesOfGlobalParam(IRGlobalParam* globalParam)
Expand All @@ -350,7 +400,7 @@ struct IntroduceExplicitGlobalContextPass
// A global shader parameter was mapped to a field
// in the context structure, so we find the appropriate key.
//
auto key = m_mapInstToContextFieldKey[globalParam];
auto fieldInfo = m_mapInstToContextFieldInfo[globalParam];

auto valType = globalParam->getFullType();
auto ptrType = builder.getPtrType(valType);
Expand All @@ -375,22 +425,34 @@ struct IntroduceExplicitGlobalContextPass
// taking the address of the corresponding field
// in the context struct and loading from it.
//
auto ptr = builder.emitFieldAddress(ptrType, contextParam, key);
auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key);
auto val = builder.emitLoad(valType, ptr);
use->set(val);
}
}

// Returns the pointer type used to refer to `globalVar`'s storage:
// a pointer to the variable's value type in the `GroupShared` address
// space when the variable has a group-shared rate, and in the
// `ThreadLocal` address space otherwise.
IRType* getGlobalVarPtrType(IRGlobalVar* globalVar)
{
    IRBuilder builder(globalVar);

    // Decide the address space first, then build the pointer type once.
    const bool isGroupShared = as<IRGroupSharedRate>(globalVar->getRate()) != nullptr;
    const AddressSpace addressSpace =
        isGroupShared ? AddressSpace::GroupShared : AddressSpace::ThreadLocal;

    return builder.getPtrType(globalVar->getDataType()->getValueType(), addressSpace);
}

void replaceUsesOfGlobalVar(IRGlobalVar* globalVar)
{
IRBuilder builder(m_module);

// A global variable was mapped to a field
// in the context structure, so we find the appropriate key.
//
auto key = m_mapInstToContextFieldKey[globalVar];
auto fieldInfo = m_mapInstToContextFieldInfo[globalVar];

auto ptrType = globalVar->getDataType();
auto ptrType = getGlobalVarPtrType(globalVar);
if (fieldInfo.needDereference)
ptrType = builder.getPtrType(kIROp_PtrType, ptrType, AddressSpace::ThreadLocal);

// We then iterate over the uses of the variable,
// being careful to defend against the use/def information
Expand All @@ -412,7 +474,9 @@ struct IntroduceExplicitGlobalContextPass
// taking the address of the corresponding field
// in the context struct.
//
auto ptr = builder.emitFieldAddress(ptrType, contextParam, key);
auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key);
if (fieldInfo.needDereference)
ptr = builder.emitLoad(ptr);
use->set(ptr);
}
}
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -3435,6 +3435,9 @@ struct IRBuilder
IRConstRefType* getConstRefType(IRType* valueType);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace);
IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { return getPtrType(op, valueType, (IRIntegerValue)addressSpace); }
IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, (IRIntegerValue)addressSpace); }

IRTextureTypeBase* getTextureType(
IRType* elementType,
IRInst* shape,
Expand Down
31 changes: 31 additions & 0 deletions tests/metal/groupshared.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//TEST:SIMPLE(filecheck=CHECK): -target metal
//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib

uniform RWStructuredBuffer<float> outputBuffer;

struct MyBlock
{
StructuredBuffer<float> b1;
StructuredBuffer<float> b2;
}
ParameterBlock<MyBlock> block;

groupshared int myArr[16];

// Helper called from the entry point. It reads the groupshared array
// `myArr` from outside the entry point and writes the first element to
// `outputBuffer`. The parameter `v` is not used by the body.
void func(float v)
{
outputBuffer[0] = myArr[0];
}

// CHECK: array<int, int(16)> threadgroup* myArr{{.*}};
// CHECK: {{\[\[}}kernel{{\]\]}} void main_kernel
// CHECK: threadgroup array<int, int(16)> myArr{{.*}};
// CHECK: (&kernelContext{{.*}})->myArr{{.*}} = &myArr{{.*}};
// CHECK-ASM: define void @main_kernel

// Compute entry point: writes the thread's x index into the groupshared
// array `myArr`, then calls `func`, which reads `myArr` back.
[numthreads(1,1,1)]
void main_kernel(uint3 tid: SV_DispatchThreadID)
{
myArr[tid.x] = tid.x;
func(3.0f);
}

0 comments on commit 1b3a428

Please sign in to comment.