diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 24979671f27..e07df0cfea2 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1110,6 +1110,11 @@ FoldingRule MergeAddAddArithmetic() { inst->opcode() == spv::Op::OpIAdd); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) { + return false; + } + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; @@ -1158,6 +1163,11 @@ FoldingRule MergeAddSubArithmetic() { inst->opcode() == spv::Op::OpIAdd); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) { + return false; + } + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; @@ -1377,6 +1387,11 @@ FoldingRule MergeGenericAddSubArithmetic() { inst->opcode() == spv::Op::OpIAdd); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;