Skip to content

Commit

Permalink
Add knowledge of cooperative matrices
Browse files Browse the repository at this point in the history
Some optimizations are not aware of cooperative matrices, and either do
nothing or assert. This commits fixes that up.
  • Loading branch information
s-perron committed Jun 6, 2024
1 parent 826ea53 commit 4a6d55e
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 5 deletions.
3 changes: 2 additions & 1 deletion source/opt/aggressive_dead_code_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,8 @@ void AggressiveDCEPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_cooperative_matrix"
});
// clang-format on
}
Expand Down
25 changes: 25 additions & 0 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ FoldingRule MergeNegateMulDivArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
return false;

Expand Down Expand Up @@ -449,6 +454,11 @@ FoldingRule MergeNegateAddSubArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
return false;

Expand Down Expand Up @@ -1062,6 +1072,11 @@ FoldingRule MergeSubNegateArithmetic() {
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;

Expand Down Expand Up @@ -1228,6 +1243,11 @@ FoldingRule MergeSubAddArithmetic() {
inst->opcode() == spv::Op::OpISub);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
Expand Down Expand Up @@ -1294,6 +1314,11 @@ FoldingRule MergeSubSubArithmetic() {
inst->opcode() == spv::Op::OpISub);
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
return false;
}

analysis::ConstantManager* const_mgr = context->get_constant_mgr();
bool uses_float = HasFloatingPoint(type);
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
Expand Down
2 changes: 1 addition & 1 deletion source/opt/local_access_chain_convert_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"});
"SPV_NV_compute_shader_derivatives", "SPV_KHR_cooperative_matrix"});
}

bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
Expand Down
3 changes: 2 additions & 1 deletion source/opt/local_single_block_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"});
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_cooperative_matrix"});
}

} // namespace opt
Expand Down
3 changes: 2 additions & 1 deletion source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives"});
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_cooperative_matrix"});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
std::vector<Instruction*> users;
Expand Down
1 change: 1 addition & 0 deletions source/opt/mem_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
case spv::Op::OpTypeSampler:
case spv::Op::OpTypeSampledImage:
case spv::Op::OpTypePointer:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
break;
Expand Down
2 changes: 1 addition & 1 deletion source/opt/ssa_rewrite_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
// Debug logging (0: Off, 1-N: Verbosity level). Replace this with the
// implementation done for
// https://github.com/KhronosGroup/SPIRV-Tools/issues/1351
// #define SSA_REWRITE_DEBUGGING_LEVEL 3
#define SSA_REWRITE_DEBUGGING_LEVEL 3

#ifdef SSA_REWRITE_DEBUGGING_LEVEL
#include <ostream>
Expand Down

0 comments on commit 4a6d55e

Please sign in to comment.