Skip to content

Commit

Permalink
Handle bool in HS input and outputs (microsoft#6129)
Browse files Browse the repository at this point in the history
Inputs and outputs that are boolean in HLSL cannot be bools in SPIR-V
because it is not allowed by the spec. So they need to be treated as
ints in the interface.

This was not tested with Hull shaders, where these input and outputs
must also be turned into arrays. This commit adds the code that will
handle casting the arrays between int and bool.

Fixes microsoft#3744
  • Loading branch information
s-perron authored Jan 9, 2024
1 parent eb4cec4 commit 849f8b8
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 14 deletions.
56 changes: 42 additions & 14 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2766,6 +2766,23 @@ void DeclResultIdMapper::storeToShaderOutputVariable(
const StageVarDataBundle &stageVarData) {
SpirvInstruction *ptr = varInstr;

// Since boolean output stage variables are represented as unsigned
// integers, we must cast the value to uint before storing.
if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type,
stageVarData.semantic->getKind(),
stageVarData.sigPoint->GetKind())) {
QualType finalType = varInstr->getAstResultType();
if (stageVarData.arraySize != 0) {
// We assume that we will only have to write to a single value of the
// array, so we have to cast to the element type of the array, and not the
// array type.
assert(stageVarData.invocationId.hasValue());
finalType = finalType->getAsArrayTypeUnsafe()->getElementType();
}
value = theEmitter.castToType(value, stageVarData.type, finalType,
stageVarData.decl->getLocation());
}

// Special handling of SV_TessFactor HS patch constant output.
// TessLevelOuter is always an array of size 4 in SPIR-V, but
// SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
Expand Down Expand Up @@ -2831,16 +2848,6 @@ void DeclResultIdMapper::storeToShaderOutputVariable(
ptr->setStorageClass(spv::StorageClass::Output);
spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation());
}
// Since boolean output stage variables are represented as unsigned
// integers, we must cast the value to uint before storing.
else if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type,
stageVarData.semantic->getKind(),
stageVarData.sigPoint->GetKind())) {
value = theEmitter.castToType(value, stageVarData.type,
varInstr->getAstResultType(),
stageVarData.decl->getLocation());
spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation());
}
// For all normal cases
else {
spvBuilder.createStore(ptr, value, stageVarData.decl->getLocation());
Expand Down Expand Up @@ -2983,9 +2990,30 @@ SpirvInstruction *DeclResultIdMapper::loadShaderInputVariable(
if (isBooleanStageIOVar(stageVarData.decl, stageVarData.type,
stageVarData.semantic->getKind(),
stageVarData.sigPoint->GetKind())) {
load = theEmitter.castToType(load, varInstr->getAstResultType(),
stageVarData.type,
stageVarData.decl->getLocation());

if (stageVarData.arraySize == 0) {
load = theEmitter.castToType(load, varInstr->getAstResultType(),
stageVarData.type,
stageVarData.decl->getLocation());
} else {
llvm::SmallVector<SpirvInstruction *, 8> fields;
SourceLocation loc = stageVarData.decl->getLocation();
QualType originalScalarType = varInstr->getAstResultType()
->castAsArrayTypeUnsafe()
->getElementType();
for (uint32_t idx = 0; idx < stageVarData.arraySize; ++idx) {
SpirvInstruction *field = spvBuilder.createCompositeExtract(
originalScalarType, load, {idx}, loc);
field = theEmitter.castToType(field, field->getAstResultType(),
stageVarData.type, loc);
fields.push_back(field);
}

QualType finalType = astContext.getConstantArrayType(
stageVarData.type, llvm::APInt(32, stageVarData.arraySize),
clang::ArrayType::Normal, 0);
load = spvBuilder.createCompositeConstruct(finalType, fields, loc);
}
}
return load;
}
Expand Down Expand Up @@ -3237,7 +3265,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvInterfaceVariable(
const StageVarDataBundle &stageVarData) {
// The evalType will be the type of the interface variable in SPIR-V.
// The type of the variable used in the body of the function will still be
// `type`.
// `stageVarData.type`.
QualType evalType = getTypeForSpirvStageVariable(stageVarData);

const auto *builtinAttr = stageVarData.decl->getAttr<VKBuiltInAttr>();
Expand Down
47 changes: 47 additions & 0 deletions tools/clang/test/CodeGenSPIRV/hs.bool.input.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: %dxc -T hs_6_0 -E Hull -fcgl %s -spirv | FileCheck %s

struct ControlPoint
{
bool b : MY_BOOL;
};

struct HullPatchOut {
float edge [3] : SV_TessFactor;
float inside : SV_InsideTessFactor;
};

// Check that the wrapper function correctly copies `v` as a parameter to Hull.
// CHECK: [[ld:%[0-9]+]] = OpLoad %_arr_uint_uint_3 %in_var_MY_BOOL
// CHECK: [[element1:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 0
// CHECK: [[bool1:%[0-9]+]] = OpINotEqual %bool [[element1]] %uint_0
// CHECK: [[element2:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 1
// CHECK: [[bool2:%[0-9]+]] = OpINotEqual %bool [[element2]] %uint_0
// CHECK: [[element3:%[0-9]+]] = OpCompositeExtract %uint [[ld:%[0-9]+]] 2
// CHECK: [[bool3:%[0-9]+]] = OpINotEqual %bool [[element3]] %uint_0
// CHECK: [[bool_array:%[0-9]+]] = OpCompositeConstruct %_arr_bool_uint_3 [[bool1]] [[bool2]] [[bool3]]
// CHECK: [[element1:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 0
// CHECK: [[cp1:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element1]]
// CHECK: [[element2:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 1
// CHECK: [[cp2:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element2]]
// CHECK: [[element3:%[0-9]+]] = OpCompositeExtract %bool [[bool_array]] 2
// CHECK: [[cp3:%[0-9]+]] = OpCompositeConstruct %ControlPoint [[element3]]
// CHECK: [[v:%[0-9]+]] = OpCompositeConstruct %_arr_ControlPoint_uint_3 [[cp1]] [[cp2]] [[cp3]]
// CHECK: OpStore %param_var_v [[v]]
// CHECK: [[ret:%[0-9]+]] = OpFunctionCall %ControlPoint %src_Hull %param_var_v %param_var_id

// Check that the return value is correctly copied to the output variable.
// CHECK: [[ret_bool:%[0-9]+]] = OpCompositeExtract %bool [[ret]] 0
// CHECK: [[ret_int:%[0-9]+]] = OpSelect %uint [[ret_bool]] %uint_1 %uint_0
// CHECK: [[out_var:%[0-9]+]] = OpAccessChain %_ptr_Output_uint %out_var_MY_BOOL
// CHECK: OpStore [[out_var]] [[ret_int]]

[domain("tri")]
[partitioning("fractional_odd")]
[outputtopology("triangle_cw")]
[patchconstantfunc("HullConst")]
[outputcontrolpoints(3)]
ControlPoint Hull (InputPatch<ControlPoint,3> v, uint id : SV_OutputControlPointID) { return v[id]; }
HullPatchOut HullConst (InputPatch<ControlPoint,3> v) { return (HullPatchOut)0; }

[domain("tri")]
float4 Domain (const OutputPatch<ControlPoint,3> vi) : SV_Position { return (vi[0].b ? 1 : 0); }

0 comments on commit 849f8b8

Please sign in to comment.