From 1b3a428bfa24350d9d69b092747b4ad142b7c4b4 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 6 May 2024 19:21:03 -0700 Subject: [PATCH] Support groupshared variables for Metal. (#4116) --- source/slang/slang-emit-metal.cpp | 42 ++++++++- source/slang/slang-emit-metal.h | 2 + .../slang-ir-explicit-global-context.cpp | 90 ++++++++++++++++--- source/slang/slang-ir-insts.h | 3 + tests/metal/groupshared.slang | 31 +++++++ 5 files changed, 151 insertions(+), 17 deletions(-) create mode 100644 tests/metal/groupshared.slang diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 7580ed74dc..2c327b6134 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -494,7 +494,7 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_ParameterBlockType: case kIROp_ConstantBufferType: { - emitType((IRType*)type->getOperand(0)); + emitSimpleTypeImpl((IRType*)type->getOperand(0)); m_writer->emit(" constant*"); return; } @@ -607,11 +607,17 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) } } -void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace) +void MetalSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) { - if (as(rate)) + switch (type->getOp()) { - m_writer->emit("threadgroup "); + case kIROp_ArrayType: + emitSimpleType(type); + emitDeclarator(declarator); + break; + default: + Super::_emitType(type, declarator); + break; } } @@ -796,6 +802,34 @@ void MetalSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTy // We emit packoffset as a semantic in `emitSemantic`, so nothing to do here. 
} +void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) +{ + if (as(rate)) + { + m_writer->emit("threadgroup "); + return; + } + + switch ((AddressSpace)addressSpace) + { + case AddressSpace::GroupShared: + m_writer->emit("threadgroup "); + break; + case AddressSpace::Uniform: + m_writer->emit("constant "); + break; + case AddressSpace::Global: + m_writer->emit("device "); + break; + case AddressSpace::ThreadLocal: + m_writer->emit("thread "); + break; + default: + break; + } +} + + void MetalSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst) { SLANG_UNUSED(varInst); diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index d925365daf..fc13901433 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -57,6 +57,8 @@ class MetalSourceEmitter : public CLikeSourceEmitter void emitFuncParamLayoutImpl(IRInst* param); + virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE; + void _emitHLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type); void _emitHLSLTextureType(IRTextureTypeBase* texType); diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp index f63ceb71e6..9bbaf38759 100644 --- a/source/slang/slang-ir-explicit-global-context.cpp +++ b/source/slang/slang-ir-explicit-global-context.cpp @@ -24,6 +24,11 @@ struct IntroduceExplicitGlobalContextPass List m_globalVars; List m_entryPoints; + enum class GlobalObjectKind + { + GlobalParam, GlobalVar + }; + void processModule() { IRBuilder builder(m_module); @@ -181,14 +186,14 @@ struct IntroduceExplicitGlobalContextPass // parameters, we create a field that exactly matches its type. 
// - createContextStructField(globalParam, globalParam->getFullType()); + createContextStructField(globalParam, GlobalObjectKind::GlobalParam, globalParam->getFullType()); } for( auto globalVar : m_globalVars ) { // A `IRGlobalVar` represents a pointer to where the variable is stored, // so we need to create a field of the pointed-to type to represent it. // - createContextStructField(globalVar, globalVar->getDataType()->getValueType()); + createContextStructField(globalVar, GlobalObjectKind::GlobalVar, getGlobalVarPtrType(globalVar)); } // Once all the fields have been created, we can process the entry points. @@ -229,10 +234,18 @@ struct IntroduceExplicitGlobalContextPass // variable parameter, and to record the context pointer // value to use for a function. // - Dictionary m_mapInstToContextFieldKey; + struct ContextFieldInfo + { + IRStructKey* key = nullptr; + + // Is this field a pointer to the actual value? + // For groupshared variables, this will be true. + bool needDereference = false; + }; + Dictionary m_mapInstToContextFieldInfo; Dictionary m_mapFuncToContextPtr; - void createContextStructField(IRInst* originalInst, IRType* type) + void createContextStructField(IRInst* originalInst, GlobalObjectKind kind, IRType* type) { // Creating a field in the context struct to represent // `originalInst` is straightforward. @@ -240,11 +253,27 @@ struct IntroduceExplicitGlobalContextPass IRBuilder builder(m_module); builder.setInsertBefore(m_contextStructType); + IRType* fieldDataType = type; + bool needDereference = false; + if (kind == GlobalObjectKind::GlobalVar) + { + auto ptrType = as(type); + if (ptrType->getAddressSpace() == (IRIntegerValue)AddressSpace::GroupShared) + { + fieldDataType = ptrType; + needDereference = true; + } + else + { + fieldDataType = as(type)->getValueType(); + } + } + // We create a "key" for the new field, and then a field // of the appropraite type. 
// auto key = builder.createStructKey(); - builder.createStructField(m_contextStructType, key, type); + builder.createStructField(m_contextStructType, key, fieldDataType); // Clone all original decorations to the new struct key. IRCloneEnv cloneEnv; @@ -254,7 +283,7 @@ struct IntroduceExplicitGlobalContextPass // for the instruction, so that we can use the key // to access the field later. // - m_mapInstToContextFieldKey.add(originalInst, key); + m_mapInstToContextFieldInfo.add(originalInst, ContextFieldInfo{ key, needDereference }); } void createContextForEntryPoint(IRFunc* entryPointFunc) @@ -321,14 +350,14 @@ struct IntroduceExplicitGlobalContextPass // for (auto entryPointParam : entryPointParams) { - auto fieldKey = m_mapInstToContextFieldKey[entryPointParam.globalParam]; + auto fieldInfo = m_mapInstToContextFieldInfo[entryPointParam.globalParam]; auto fieldType = entryPointParam.globalParam->getFullType(); auto fieldPtrType = builder.getPtrType(fieldType); // We compute the addrress of the field and store the // value of the parameter into it. // - auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldKey); + auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldInfo.key); builder.emitStore(fieldPtr, entryPointParam.entryPointParam); } @@ -341,6 +370,27 @@ struct IntroduceExplicitGlobalContextPass // run the pass in `slang-ir-explicit-global-init` first, // in order to move all initialization of globals into the // entry point functions. + // + // To support groupshared variables on Metal, we need to allocate the + // memory by defining a local variable in the entry point, and pass + // the address of that variable to the context. 
+ // + for (auto globalVar : m_globalVars) + { + auto fieldInfo = m_mapInstToContextFieldInfo[globalVar]; + if (fieldInfo.needDereference) + { + auto var = builder.emitVar(globalVar->getDataType()->getValueType(), (IRIntegerValue)AddressSpace::GroupShared); + if (auto nameDecor = globalVar->findDecoration()) + { + builder.addNameHintDecoration(var, nameDecor->getName()); + } + auto ptrPtrType = builder.getPtrType(getGlobalVarPtrType(globalVar), AddressSpace::ThreadLocal); + auto fieldPtr = builder.emitFieldAddress(ptrPtrType, contextVarPtr, fieldInfo.key); + builder.emitStore(fieldPtr, var); + } + } + } void replaceUsesOfGlobalParam(IRGlobalParam* globalParam) @@ -350,7 +400,7 @@ struct IntroduceExplicitGlobalContextPass // A global shader parameter was mapped to a field // in the context structure, so we find the appropriate key. // - auto key = m_mapInstToContextFieldKey[globalParam]; + auto fieldInfo = m_mapInstToContextFieldInfo[globalParam]; auto valType = globalParam->getFullType(); auto ptrType = builder.getPtrType(valType); @@ -375,12 +425,22 @@ struct IntroduceExplicitGlobalContextPass // taking the address of the corresponding field // in the context struct and loading from it. // - auto ptr = builder.emitFieldAddress(ptrType, contextParam, key); + auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key); auto val = builder.emitLoad(valType, ptr); use->set(val); } } + IRType* getGlobalVarPtrType(IRGlobalVar* globalVar) + { + IRBuilder builder(globalVar); + if (as(globalVar->getRate())) + { + return builder.getPtrType(globalVar->getDataType()->getValueType(), AddressSpace::GroupShared); + } + return builder.getPtrType(globalVar->getDataType()->getValueType(), AddressSpace::ThreadLocal); + } + void replaceUsesOfGlobalVar(IRGlobalVar* globalVar) { IRBuilder builder(m_module); @@ -388,9 +448,11 @@ struct IntroduceExplicitGlobalContextPass // A global variable was mapped to a field // in the context structure, so we find the appropriate key. 
// - auto key = m_mapInstToContextFieldKey[globalVar]; + auto fieldInfo = m_mapInstToContextFieldInfo[globalVar]; - auto ptrType = globalVar->getDataType(); + auto ptrType = getGlobalVarPtrType(globalVar); + if (fieldInfo.needDereference) + ptrType = builder.getPtrType(kIROp_PtrType, ptrType, AddressSpace::ThreadLocal); // We then iterate over the uses of the variable, // being careful to defend against the use/def information @@ -412,7 +474,9 @@ struct IntroduceExplicitGlobalContextPass // taking the address of the corresponding field // in the context struct. // - auto ptr = builder.emitFieldAddress(ptrType, contextParam, key); + auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key); + if (fieldInfo.needDereference) + ptr = builder.emitLoad(ptr); use->set(ptr); } } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9329e38063..5c4f01ae7c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3435,6 +3435,9 @@ struct IRBuilder IRConstRefType* getConstRefType(IRType* valueType); IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace); + IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { return getPtrType(op, valueType, (IRIntegerValue)addressSpace); } + IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, (IRIntegerValue)addressSpace); } + IRTextureTypeBase* getTextureType( IRType* elementType, IRInst* shape, diff --git a/tests/metal/groupshared.slang b/tests/metal/groupshared.slang new file mode 100644 index 0000000000..4d1f6ecac6 --- /dev/null +++ b/tests/metal/groupshared.slang @@ -0,0 +1,31 @@ +//TEST:SIMPLE(filecheck=CHECK): -target metal +//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib + +uniform RWStructuredBuffer outputBuffer; + +struct MyBlock +{ + StructuredBuffer b1; + StructuredBuffer b2; +} +ParameterBlock 
block; + +groupshared int myArr[16]; + +void func(float v) +{ + outputBuffer[0] = myArr[0]; +} + +// CHECK: array threadgroup* myArr{{.*}}; +// CHECK: {{\[\[}}kernel{{\]\]}} void main_kernel +// CHECK: threadgroup array myArr{{.*}}; +// CHECK: (&kernelContext{{.*}})->myArr{{.*}} = &myArr{{.*}}; +// CHECK-ASM: define void @main_kernel + +[numthreads(1,1,1)] +void main_kernel(uint3 tid: SV_DispatchThreadID) +{ + myArr[tid.x] = tid.x; + func(3.0f); +}