Skip to content

Commit

Permalink
Add float tests, and a handle a couple more cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Jun 25, 2024
1 parent 6a61154 commit f1b287f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
20 changes: 20 additions & 0 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ FoldingRule ReciprocalFDiv() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

if (!inst->IsFloatingPointFoldingAllowed()) return false;

uint32_t width = ElementWidth(type);
Expand Down Expand Up @@ -755,6 +760,11 @@ FoldingRule MergeMulDivArithmetic() {

const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

if (!inst->IsFloatingPointFoldingAllowed()) return false;

uint32_t width = ElementWidth(type);
Expand Down Expand Up @@ -873,6 +883,11 @@ FoldingRule MergeDivDivArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

if (!inst->IsFloatingPointFoldingAllowed()) return false;

uint32_t width = ElementWidth(type);
Expand Down Expand Up @@ -946,6 +961,11 @@ FoldingRule MergeDivMulArithmetic() {

const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

if (!inst->IsFloatingPointFoldingAllowed()) return false;

uint32_t width = ElementWidth(type);
Expand Down
48 changes: 48 additions & 0 deletions test/opt/fold_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4753,6 +4753,54 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
"%2 = OpFDiv %half %half_1 %half_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 24: Don't fold OpFNegate for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFNegate %float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 25: Don't fold OpIAdd for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFAdd %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 26: Don't fold OpISub for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFSub %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 27: Don't fold OpIMul for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFMul %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 28: Don't fold OpSDiv for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpFDiv %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 29: Don't fold OpMatrixTimesScalar for cooperative matrices.
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpMatrixTimesScalar %float_coop_matrix %undef_float_coop_matrix %float_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0)
));

Expand Down

0 comments on commit f1b287f

Please sign in to comment.