From 849f8b884b5e118ff46c1b16f5456e58704a2c3e Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Tue, 9 Jan 2024 10:03:52 -0500 Subject: [PATCH] Handle bool in HS input and outputs (#6129) Inputs and outputs that are boolean in HLSL cannot be bools in SPIR-V because it is not allowed by the spec. So they need to be treated as ints in the interface. This was not tested with Hull shaders, where these input and outputs must also be turned into arrays. This commit adds the code that will handle casting the arrays between int and bool. Fixes #3744 --- tools/clang/lib/SPIRV/DeclResultIdMapper.cpp | 56 ++++++++++++++----- .../test/CodeGenSPIRV/hs.bool.input.hlsl | 47 ++++++++++++++++ 2 files changed, 89 insertions(+), 14 deletions(-) create mode 100644 tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 8a4ea0b746..d51ccef3a1 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -2766,6 +2766,23 @@ void DeclResultIdMapper::storeToShaderOutputVariable( const StageVarDataBundle &stageVarData) { SpirvInstruction *ptr = varInstr; + // Since boolean output stage variables are represented as unsigned + // integers, we must cast the value to uint before storing. + if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + QualType finalType = varInstr->getAstResultType(); + if (stageVarData.arraySize != 0) { + // We assume that we will only have to write to a single value of the + // array, so we have to cast to the element type of the array, and not the + // array type. + assert(stageVarData.invocationId.hasValue()); + finalType = finalType->getAsArrayTypeUnsafe()->getElementType(); + } + value = theEmitter.castToType(value, stageVarData.type, finalType, + stageVarData.decl->getLocation()); + } + // Special handling of SV_TessFactor HS patch constant output. // TessLevelOuter is always an array of size 4 in SPIR-V, but // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the @@ -2831,16 +2848,6 @@ void DeclResultIdMapper::storeToShaderOutputVariable( ptr->setStorageClass(spv::StorageClass::Output); spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } - // Since boolean output stage variables are represented as unsigned - // integers, we must cast the value to uint before storing. - else if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, - stageVarData.semantic->getKind(), - stageVarData.sigPoint->GetKind())) { - value = theEmitter.castToType(value, stageVarData.type, - varInstr->getAstResultType(), - stageVarData.decl->getLocation()); - spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); - } // For all normal cases else { spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); @@ -2983,9 +2990,30 @@ SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable( if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, stageVarData.semantic->getKind(), stageVarData.sigPoint->GetKind())) { - load = theEmitter.castToType(load, varInstr->getAstResultType(), - stageVarData.type, - stageVarData.decl->getLocation()); + + if (stageVarData.arraySize == 0) { + load = theEmitter.castToType(load, varInstr->getAstResultType(), + stageVarData.type, + stageVarData.decl->getLocation()); + } else { + llvm::SmallVector fields; + SourceLocation loc = stageVarData.decl->getLocation(); + QualType originalScalarType = varInstr->getAstResultType() + ->castAsArrayTypeUnsafe() + ->getElementType(); + for (uint32_t idx = 0; idx < stageVarData.arraySize; ++idx) { + SpirvInstruction *field = spvBuilder.createCompositeExtract( + originalScalarType, load, {idx}, loc); + field = theEmitter.castToType(field, field->getAstResultType(), + stageVarData.type, loc); + fields.push_back(field); + } + + QualType finalType = astContext.getConstantArrayType( + stageVarData.type, llvm::APInt(32, stageVarData.arraySize), + clang::ArrayType::Normal, 0); + load = spvBuilder.createCompositeConstruct(finalType, fields, loc); + } } return load; } @@ -3237,7 +3265,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( const StageVarDataBundle &stageVarData) { // The evalType will be the type of the interface variable in SPIR-V. // The type of the variable used in the body of the function will still be - // `type`. + // `stageVarData.type`. QualType evalType = getTypeForSpirvStageVariable(stageVarData); const auto *builtinAttr = stageVarData.decl->getAttr(); diff --git a/tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl b/tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl new file mode 100644 index 0000000000..7f06b2bf87 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl @@ -0,0 +1,47 @@ +// RUN: %dxc -T hs_6_0 -E Hull -fcgl %s -spirv | FileCheck %s + +struct ControlPoint +{ + bool b : MY_BOOL; +}; + +struct HullPatchOut { + float edge [3] : SV_TessFactor; + float inside : SV_InsideTessFactor; +}; + +// Check that the wrapper function correctly copies `v` as a parameter to Hull. +// CHECK: [[ld:%[0-9]+]] = OpLoad %_arr_uint_uint_3 %in_var_MY_BOOL +// CHECK: [[element1:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 0 +// CHECK: [[bool1:%[0-9]+]] = OpINotEqual %bool [[element1]] %uint_0 +// CHECK: [[element2:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 1 +// CHECK: [[bool2:%[0-9]+]] = OpINotEqual %bool [[element2]] %uint_0 +// CHECK: [[element3:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 2 +// CHECK: [[bool3:%[0-9]+]] = OpINotEqual %bool [[element3]] %uint_0 +// CHECK: [[bool_array:%[0-9]+]] = OpCompositeConstruct %_arr_bool_uint_3 [[bool1]] [[bool2]] [[bool3]] +// CHECK: [[element1:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 0 +// CHECK: [[cp1:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element1]] +// CHECK: [[element2:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 1 +// CHECK: [[cp2:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element2]] +// CHECK: [[element3:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 2 +// CHECK: [[cp3:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element3]] +// CHECK: [[v:%[0-9]+]] = OpCompositeConstruct %_arr_ControlPoint_uint_3 [[cp1]] [[cp2]] [[cp3]] +// CHECK: OpStore %param_var_v [[v]] +// CHECK: [[ret:%[0-9]+]] = OpFunctionCall %ControlPoint %src_Hull %param_var_v %param_var_id + +// Check that the return value is correctly copied to the output variable. +// CHECK: [[ret_bool:%[0-9]+]] = OpCompositeExtract %bool [[ret]] 0 +// CHECK: [[ret_int:%[0-9]+]] = OpSelect %uint [[ret_bool]] %uint_1 %uint_0 +// CHECK: [[out_var:%[0-9]+]] = OpAccessChain %_ptr_Output_uint %out_var_MY_BOOL +// CHECK: OpStore [[out_var]] [[ret_int]] + +[domain("tri")] +[partitioning("fractional_odd")] +[outputtopology("triangle_cw")] +[patchconstantfunc("HullConst")] +[outputcontrolpoints(3)] +ControlPoint Hull (InputPatch v, uint id : SV_OutputControlPointID) { return v[id]; } +HullPatchOut HullConst (InputPatch v) { return (HullPatchOut)0; } + +[domain("tri")] +float4 Domain (const OutputPatch vi) : SV_Position { return (vi[0].b ? 1 : 0); }