Skip to content

Commit

Permalink
Add knowledge of coop matrix.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Jun 6, 2024
1 parent ce46482 commit 826ea53
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,11 @@ FoldingRule MergeAddAddArithmetic() {
inst->opcode() == spv::Op::OpIAdd);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) {
return false;
}

analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
Expand Down Expand Up @@ -1158,6 +1163,11 @@ FoldingRule MergeAddSubArithmetic() {
inst->opcode() == spv::Op::OpIAdd);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) {
return false;
}

analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
Expand Down Expand Up @@ -1377,6 +1387,11 @@ FoldingRule MergeGenericAddSubArithmetic() {
inst->opcode() == spv::Op::OpIAdd);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;

Expand Down

0 comments on commit 826ea53

Please sign in to comment.