diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 8a4ea0b746..d51ccef3a1 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -2766,6 +2766,23 @@ void DeclResultIdMapper::storeToShaderOutputVariable( const StageVarDataBundle &stageVarData) { SpirvInstruction *ptr = varInstr; + // Since boolean output stage variables are represented as unsigned + // integers, we must cast the value to uint before storing. + if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, + stageVarData.semantic->getKind(), + stageVarData.sigPoint->GetKind())) { + QualType finalType = varInstr->getAstResultType(); + if (stageVarData.arraySize != 0) { + // We assume that we will only have to write to a single value of the + // array, so we have to cast to the element type of the array, and not the + // array type. + assert(stageVarData.invocationId.hasValue()); + finalType = finalType->getAsArrayTypeUnsafe()->getElementType(); + } + value = theEmitter.castToType(value, stageVarData.type, finalType, + stageVarData.decl->getLocation()); + } + // Special handling of SV_TessFactor HS patch constant output. // TessLevelOuter is always an array of size 4 in SPIR-V, but // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the @@ -2831,16 +2848,6 @@ void DeclResultIdMapper::storeToShaderOutputVariable( ptr->setStorageClass(spv::StorageClass::Output); spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); } - // Since boolean output stage variables are represented as unsigned - // integers, we must cast the value to uint before storing. - else if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, - stageVarData.semantic->getKind(), - stageVarData.sigPoint->GetKind())) { - value = theEmitter.castToType(value, stageVarData.type, - varInstr->getAstResultType(), - stageVarData.decl->getLocation()); - spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); - } // For all normal cases else { spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation()); @@ -2983,9 +2990,30 @@ SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable( if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type, stageVarData.semantic->getKind(), stageVarData.sigPoint->GetKind())) { - load = theEmitter.castToType(load, varInstr->getAstResultType(), - stageVarData.type, - stageVarData.decl->getLocation()); + + if (stageVarData.arraySize == 0) { + load = theEmitter.castToType(load, varInstr->getAstResultType(), + stageVarData.type, + stageVarData.decl->getLocation()); + } else { + llvm::SmallVector fields; + SourceLocation loc = stageVarData.decl->getLocation(); + QualType originalScalarType = varInstr->getAstResultType() + ->castAsArrayTypeUnsafe() + ->getElementType(); + for (uint32_t idx = 0; idx < stageVarData.arraySize; ++idx) { + SpirvInstruction *field = spvBuilder.createCompositeExtract( + originalScalarType, load, {idx}, loc); + field = theEmitter.castToType(field, field->getAstResultType(), + stageVarData.type, loc); + fields.push_back(field); + } + + QualType finalType = astContext.getConstantArrayType( + stageVarData.type, llvm::APInt(32, stageVarData.arraySize), + clang::ArrayType::Normal, 0); + load = spvBuilder.createCompositeConstruct(finalType, fields, loc); + } } return load; } @@ -3237,7 +3265,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable( const StageVarDataBundle &stageVarData) { // The evalType will be the type of the interface variable in SPIR-V. // The type of the variable used in the body of the function will still be - // `type`. + // `stageVarData.type`. QualType evalType = getTypeForSpirvStageVariable(stageVarData); const auto *builtinAttr = stageVarData.decl->getAttr(); diff --git a/tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl b/tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl new file mode 100644 index 0000000000..7f06b2bf87 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl @@ -0,0 +1,47 @@ +// RUN: %dxc -T hs_6_0 -E Hull -fcgl %s -spirv | FileCheck %s + +struct ControlPoint +{ + bool b : MY_BOOL; +}; + +struct HullPatchOut { + float edge [3] : SV_TessFactor; + float inside : SV_InsideTessFactor; +}; + +// Check that the wrapper function correctly copies `v` as a parameter to Hull. +// CHECK: [[ld:%[0-9]+]] = OpLoad %_arr_uint_uint_3 %in_var_MY_BOOL +// CHECK: [[element1:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 0 +// CHECK: [[bool1:%[0-9]+]] = OpINotEqual %bool [[element1]] %uint_0 +// CHECK: [[element2:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 1 +// CHECK: [[bool2:%[0-9]+]] = OpINotEqual %bool [[element2]] %uint_0 +// CHECK: [[element3:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 2 +// CHECK: [[bool3:%[0-9]+]] = OpINotEqual %bool [[element3]] %uint_0 +// CHECK: [[bool_array:%[0-9]+]] = OpCompositeConstruct %_arr_bool_uint_3 [[bool1]] [[bool2]] [[bool3]] +// CHECK: [[element1:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 0 +// CHECK: [[cp1:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element1]] +// CHECK: [[element2:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 1 +// CHECK: [[cp2:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element2]] +// CHECK: [[element3:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 2 +// CHECK: [[cp3:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element3]] +// CHECK: [[v:%[0-9]+]] = OpCompositeConstruct %_arr_ControlPoint_uint_3 [[cp1]] [[cp2]] [[cp3]] +// CHECK: OpStore %param_var_v [[v]] +// CHECK: [[ret:%[0-9]+]] = OpFunctionCall %ControlPoint %src_Hull %param_var_v %param_var_id + +// Check that the return value is correctly copied to the output variable. +// CHECK: [[ret_bool:%[0-9]+]] = OpCompositeExtract %bool [[ret]] 0 +// CHECK: [[ret_int:%[0-9]+]] = OpSelect %uint [[ret_bool]] %uint_1 %uint_0 +// CHECK: [[out_var:%[0-9]+]] = OpAccessChain %_ptr_Output_uint %out_var_MY_BOOL +// CHECK: OpStore [[out_var]] [[ret_int]] + +[domain("tri")] +[partitioning("fractional_odd")] +[outputtopology("triangle_cw")] +[patchconstantfunc("HullConst")] +[outputcontrolpoints(3)] +ControlPoint Hull (InputPatch v, uint id : SV_OutputControlPointID) { return v[id]; } +HullPatchOut HullConst (InputPatch v) { return (HullPatchOut)0; } + +[domain("tri")] +float4 Domain (const OutputPatch vi) : SV_Position { return (vi[0].b ? 1 : 0); }