diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 5f83669940..f41ef868a3 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1607,27 +1607,26 @@ bool CompositeConstructFeedingExtract( } // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or -// OpCompositeExtract instruction, and returns the type of the final element -// being accessed. -const analysis::Type* GetElementType(uint32_t type_id, - Instruction::iterator start, - Instruction::iterator end, - const analysis::TypeManager* type_mgr) { - const analysis::Type* type = type_mgr->GetType(type_id); +// OpCompositeExtract instruction, and returns the type id of the final element +// being accessed. Returns 0 if a valid type could not be found. +uint32_t GetElementType(uint32_t type_id, Instruction::iterator start, + Instruction::iterator end, + const analysis::DefUseManager* def_use_manager) { for (auto index : make_range(std::move(start), std::move(end))) { + const Instruction* type_inst = def_use_manager->GetDef(type_id); assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER && index.words.size() == 1); - if (auto* array_type = type->AsArray()) { - type = array_type->element_type(); - } else if (auto* matrix_type = type->AsMatrix()) { - type = matrix_type->element_type(); - } else if (auto* struct_type = type->AsStruct()) { - type = struct_type->element_types()[index.words[0]]; + if (type_inst->opcode() == spv::Op::OpTypeArray) { + type_id = type_inst->GetSingleWordInOperand(0); + } else if (type_inst->opcode() == spv::Op::OpTypeMatrix) { + type_id = type_inst->GetSingleWordInOperand(0); + } else if (type_inst->opcode() == spv::Op::OpTypeStruct) { + type_id = type_inst->GetSingleWordInOperand(index.words[0]); } else { - type = nullptr; + return 0; } } - return type; + return type_id; } // Returns true of |inst_1| and |inst_2| have the same indexes that will be used @@ -1712,16 +1711,11 @@ bool CompositeExtractFeedingConstruct( // The last check it to see that the object being extracted from is the // correct type. Instruction* original_inst = def_use_mgr->GetDef(original_id); - analysis::TypeManager* type_mgr = context->get_type_mgr(); - const analysis::Type* original_type = + uint32_t original_type_id = GetElementType(original_inst->type_id(), first_element_inst->begin() + 3, - first_element_inst->end() - 1, type_mgr); - - if (original_type == nullptr) { - return false; - } + first_element_inst->end() - 1, def_use_mgr); - if (inst->type_id() != type_mgr->GetId(original_type)) { + if (inst->type_id() != original_type_id) { return false; } @@ -2015,9 +2009,11 @@ bool DoInsertedValuesCoverEntireObject( // inserted by the OpCompositeInsert instruction |inst|. const analysis::Type* GetContainerType(Instruction* inst) { assert(inst->opcode() == spv::Op::OpCompositeInsert); + analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr(); + uint32_t container_type_id = GetElementType( + inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager); analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); - return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1, - type_mgr); + return type_mgr->GetType(container_type_id); } // Returns an OpCompositeConstruct instruction that build an object with diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 255449dbbf..35828ab22f 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -7827,7 +7827,21 @@ ::testing::Values( "%5 = OpCompositeInsert %int_arr_2 %int_1 %4 1\n" + "OpReturn\n" + "OpFunctionEnd", - 5, true) + 5, true), + // Test case 19: Don't fold for isomorphic structs + InstructionFoldingCase( + Header() + + "%structA = OpTypeStruct %ulong\n" + + "%structB = OpTypeStruct %ulong\n" + + "%structC = OpTypeStruct %structB\n" + + "%struct_a_undef = OpUndef %structA\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%3 = OpCompositeExtract %ulong %struct_a_undef 0\n" + + "%4 = OpCompositeConstruct %structB %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, false) )); INSTANTIATE_TEST_SUITE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,