From 4a6d55e764ce88038043ad64406782cd28ad935b Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 5 Jun 2024 14:42:26 -0400 Subject: [PATCH] Add knowledge of cooperative matrices Some optimizations are not aware of cooperative matrices, and either do nothing or assert. This commits fixes that up. --- source/opt/aggressive_dead_code_elim_pass.cpp | 3 ++- source/opt/folding_rules.cpp | 25 +++++++++++++++++++ .../opt/local_access_chain_convert_pass.cpp | 2 +- source/opt/local_single_block_elim_pass.cpp | 3 ++- source/opt/local_single_store_elim_pass.cpp | 3 ++- source/opt/mem_pass.cpp | 1 + source/opt/ssa_rewrite_pass.cpp | 2 +- 7 files changed, 34 insertions(+), 5 deletions(-) diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 4737da5f9cf..e9c831e544f 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -1004,7 +1004,8 @@ void AggressiveDCEPass::InitExtensions() { "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", - "SPV_NV_compute_shader_derivatives" + "SPV_NV_compute_shader_derivatives", + "SPV_KHR_cooperative_matrix" }); // clang-format on } diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index e07df0cfea2..3b2b9b0b11d 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -388,6 +388,11 @@ FoldingRule MergeNegateMulDivArithmetic() { analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; @@ -449,6 +454,11 @@ FoldingRule MergeNegateAddSubArithmetic() { analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; @@ -1062,6 +1072,11 @@ FoldingRule MergeSubNegateArithmetic() { analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; @@ -1228,6 +1243,11 @@ FoldingRule MergeSubAddArithmetic() { inst->opcode() == spv::Op::OpISub); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; @@ -1294,6 +1314,11 @@ FoldingRule MergeSubSubArithmetic() { inst->opcode() == spv::Op::OpISub); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + + if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + return false; + } + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index 7ba75cb7a42..fd3b612cd26 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -429,7 +429,7 @@ void LocalAccessChainConvertPass::InitExtensions() { "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model", "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", - "SPV_NV_compute_shader_derivatives"}); + "SPV_NV_compute_shader_derivatives", "SPV_KHR_cooperative_matrix"}); } bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds( diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp index d7a9295e846..dd0b594a9d1 100644 --- a/source/opt/local_single_block_elim_pass.cpp +++ b/source/opt/local_single_block_elim_pass.cpp @@ -291,7 +291,8 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() { "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", - "SPV_NV_compute_shader_derivatives"}); + "SPV_NV_compute_shader_derivatives", + "SPV_KHR_cooperative_matrix"}); } } // namespace opt diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp index 7cd6b0eb476..aa7a7569550 100644 --- a/source/opt/local_single_store_elim_pass.cpp +++ b/source/opt/local_single_store_elim_pass.cpp @@ -141,7 +141,8 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() { "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", - "SPV_NV_compute_shader_derivatives"}); + "SPV_NV_compute_shader_derivatives", + "SPV_KHR_cooperative_matrix"}); } bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) { std::vector users; diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 9972c4f75f5..6e59bf0f8d8 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -43,6 +43,7 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const { case spv::Op::OpTypeSampler: case spv::Op::OpTypeSampledImage: case spv::Op::OpTypePointer: + case spv::Op::OpTypeCooperativeMatrixKHR: return true; default: break; diff --git a/source/opt/ssa_rewrite_pass.cpp b/source/opt/ssa_rewrite_pass.cpp index 3eb4ec3f8e6..23a984dfdd7 100644 --- a/source/opt/ssa_rewrite_pass.cpp +++ b/source/opt/ssa_rewrite_pass.cpp @@ -52,7 +52,7 @@ // Debug logging (0: Off, 1-N: Verbosity level). Replace this with the // implementation done for // https://github.com/KhronosGroup/SPIRV-Tools/issues/1351 -// #define SSA_REWRITE_DEBUGGING_LEVEL 3 +#define SSA_REWRITE_DEBUGGING_LEVEL 3 #ifdef SSA_REWRITE_DEBUGGING_LEVEL #include