diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index a928fa21e3d..41f535d85d7 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -313,6 +313,11 @@ FoldingRule ReciprocalFDiv() { analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); @@ -755,6 +760,11 @@ FoldingRule MergeMulDivArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); @@ -873,6 +883,11 @@ FoldingRule MergeDivDivArithmetic() { analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); @@ -946,6 +961,11 @@ FoldingRule MergeDivMulArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 3a0884d183a..e2d9d7cc185 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -4753,6 +4753,54 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes "%2 = OpFDiv %half %half_1 %half_2\n" + "OpReturn\n" + "OpFunctionEnd", + 2, 0), + // Test case 24: Don't fold OpFNegate for cooperative matrices. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %float_coop_matrix %undef_float_coop_matrix\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 25: Don't fold OpIAdd for cooperative matrices. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFAdd %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 26: Don't fold OpISub for cooperative matrices. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFSub %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 27: Don't fold OpIMul for cooperative matrices. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFMul %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 28: Don't fold OpSDiv for cooperative matrices. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 29: Don't fold OpMatrixTimesScalar for cooperative matrices. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpMatrixTimesScalar %float_coop_matrix %undef_float_coop_matrix %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", 2, 0) ));