Skip to content

Commit

Permalink
Add int tests, and a handle a couple more cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Jun 25, 2024
1 parent 86d31ab commit 6a61154
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
10 changes: 10 additions & 0 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,11 @@ FoldingRule MergeMulMulArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
return false;

Expand Down Expand Up @@ -823,6 +828,11 @@ FoldingRule MergeMulNegateArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;

Expand Down
64 changes: 64 additions & 0 deletions test/opt/fold_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ OpCapability Float64
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability CooperativeMatrixKHR
OpExtension "SPV_KHR_cooperative_matrix"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
Expand Down Expand Up @@ -434,6 +436,12 @@ OpName %main "main"
%ushort_0xBC00 = OpConstant %ushort 0xBC00
%short_0xBC00 = OpConstant %short 0xBC00
%int_arr_2_undef = OpUndef %int_arr_2
%int_coop_matrix = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_3 %uint_32 %uint_0
%undef_int_coop_matrix = OpUndef %int_coop_matrix
%uint_coop_matrix = OpTypeCooperativeMatrixKHR %uint %uint_3 %uint_3 %uint_32 %uint_0
%undef_uint_coop_matrix = OpUndef %uint_coop_matrix
%float_coop_matrix = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_3 %uint_32 %uint_0
%undef_float_coop_matrix = OpUndef %float_coop_matrix
)";

return header;
Expand Down Expand Up @@ -4148,6 +4156,62 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe
"%2 = OpSLessThan %bool %long_0 %long_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 41: Don't fold OpSNegate for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpSNegate %int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 42: Don't fold OpIAdd for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 43: Don't fold OpISub for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpISub %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 44: Don't fold OpIMul for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIMul %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 45: Don't fold OpSDiv for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpSDiv %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 46: Don't fold OpUDiv for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpUDiv %uint_coop_matrix %undef_uint_coop_matrix %undef_uint_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 47: Don't fold OpMatrixTimesScalar for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpMatrixTimesScalar %uint_coop_matrix %undef_uint_coop_matrix %uint_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0)
));

Expand Down

0 comments on commit 6a61154

Please sign in to comment.