diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index e676974c8cd..3a0186d7a4a 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -676,7 +676,6 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { return [scalar_rule](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { - analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -716,6 +715,63 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { }; } +// Returns a |ConstantFoldingRule| that folds binary scalar ops +// using |scalar_rule| and unary vectors ops by applying +// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| +// that is returned assumes that |constants| contains 2 entries. If they are +// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| +// whose element type is |Float| or |Integer|. +ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { + return [scalar_rule](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); + const analysis::Vector* vector_type = result_type->AsVector(); + + const analysis::Constant* arg1 = + (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0]; + const analysis::Constant* arg2 = + (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1]; + + if (arg1 == nullptr) { + return nullptr; + } + if (arg2 == nullptr) { + return nullptr; + } + + if (vector_type != nullptr) { + std::vector a_components; + std::vector b_components; + std::vector results_components; + + a_components = arg1->GetVectorComponents(const_mgr); + b_components = arg2->GetVectorComponents(const_mgr); + + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], + b_components[i], const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; + } + } + + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else { + return scalar_rule(result_type, arg1, arg2, const_mgr); + } + }; +} + // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops // using |scalar_rule| and unary float point vectors ops by applying // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| @@ -1587,6 +1643,60 @@ BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double, return nullptr; }; } + +template +BinaryScalarFoldingRule FoldBinaryUnsignedIntegerOperation( + uint64_t (*op)(uint64_t, uint64_t)) { + return + [op](const analysis::Type* result_type, const analysis::Constant* a, + const analysis::Constant* b, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + const analysis::Integer* integer_type = a->type()->AsInteger(); + assert(integer_type != nullptr); + assert(integer_type == result_type->AsInteger()); + assert(integer_type == b->type()->AsInteger()); + + // In SPIR-V, the signedness of the operands is determined by the + // opcode, and not by the type of the operands. This is why we use the + // template argument to determine how to interpret the operands. + uint64_t ia = + (IsSigned ? a->GetSignExtendedValue() : a->GetZeroExtendedValue()); + uint64_t ib = + (IsSigned ? b->GetSignExtendedValue() : b->GetZeroExtendedValue()); + uint64_t result = op(ia, ib); + + std::vector words; + if (integer_type->width() == 64) { + // In the 64-bit case, two words are needed to represent the value. + words = {static_cast(result), + static_cast(result >> 32)}; + } else { + // In all other cases, only a single word is needed. + assert(integer_type->width() <= 32); + + uint64_t mask_for_last_bit = 1ull << (integer_type->width() - 1); + uint64_t significant_bits = (mask_for_last_bit << 1) - 1ull; + if (integer_type->IsSigned()) { + // Sign-extend the most significant bit. + if (result & mask_for_last_bit) { + // Set upper bits to 1 + result |= ~significant_bits; + } else { + // Clear the upper bits + result &= significant_bits; + } + } else { + // Mask any bits largest than the width. + result &= (1ull << integer_type->width()) - 1ull; + } + + words = {static_cast(result)}; + } + return const_mgr->GetConstant(result_type, words); + }; +} + } // namespace void ConstantFoldingRules::AddFoldingRules() { @@ -1662,6 +1772,46 @@ void ConstantFoldingRules::AddFoldingRules() { rules_[spv::Op::OpSNegate].push_back(FoldSNegate()); rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16()); + rules_[spv::Op::OpIAdd].push_back( + FoldBinaryOp(FoldBinaryUnsignedIntegerOperation( + [](uint64_t a, uint64_t b) { return a + b; }))); + rules_[spv::Op::OpISub].push_back( + FoldBinaryOp(FoldBinaryUnsignedIntegerOperation( + [](uint64_t a, uint64_t b) { return a - b; }))); + rules_[spv::Op::OpIMul].push_back( + FoldBinaryOp(FoldBinaryUnsignedIntegerOperation( + [](uint64_t a, uint64_t b) { return a * b; }))); + rules_[spv::Op::OpUDiv].push_back( + FoldBinaryOp(FoldBinaryUnsignedIntegerOperation( + [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); }))); + rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp( + FoldBinaryUnsignedIntegerOperation([](uint64_t a, uint64_t b) { + return (b != 0 ? static_cast(static_cast(a) / + static_cast(b)) + : 0); + }))); + rules_[spv::Op::OpUMod].push_back( + FoldBinaryOp(FoldBinaryUnsignedIntegerOperation( + [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); }))); + + rules_[spv::Op::OpSRem].push_back(FoldBinaryOp( + FoldBinaryUnsignedIntegerOperation([](uint64_t a, uint64_t b) { + return (b != 0 ? static_cast(static_cast(a) % + static_cast(b)) + : 0); + }))); + + rules_[spv::Op::OpSMod].push_back(FoldBinaryOp( + FoldBinaryUnsignedIntegerOperation([](uint64_t a, uint64_t b) { + if (b == 0) return static_cast(0ull); + + int64_t signed_a = static_cast(a); + int64_t signed_b = static_cast(b); + int64_t result = signed_a % signed_b; + if ((signed_b < 0) != (result < 0)) result += signed_b; + return static_cast(result); + }))); + // Add rules for GLSLstd450 FeatureManager* feature_manager = context_->get_feature_mgr(); uint32_t ext_inst_glslstd450_id = diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index df4db39a9dd..6b705c61716 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -273,6 +273,7 @@ OpName %main "main" %short_0 = OpConstant %short 0 %short_2 = OpConstant %short 2 %short_3 = OpConstant %short 3 +%short_n5 = OpConstant %short -5 %ubyte_1 = OpConstant %ubyte 1 %byte_n1 = OpConstant %byte -1 %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. @@ -292,7 +293,11 @@ OpName %main "main" %long_1 = OpConstant %long 1 %long_2 = OpConstant %long 2 %long_3 = OpConstant %long 3 +%long_n3 = OpConstant %long -3 +%long_7 = OpConstant %long 7 +%long_n7 = OpConstant %long -7 %long_10 = OpConstant %long 10 +%long_n4611686018427387904 = OpConstant %long -4611686018427387904 %long_4611686018427387904 = OpConstant %long 4611686018427387904 %long_n1 = OpConstant %long -1 %long_n3689348814741910323 = OpConstant %long -3689348814741910323 @@ -1019,10 +1024,185 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpSNegate %ushort %ushort_0x4400\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 0xBC00 /* expected to be zero extended. */) + 2, 0xBC00 /* expected to be zero extended. */), + // Test case 67: Fold 2 + 3 (short) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpIAdd %short %short_2 %short_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 5), + // Test case 68: Fold 2 + -5 (short) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpIAdd %short %short_2 %short_n5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -3) )); // clang-format on +using LongIntegerInstructionFoldingTest = + ::testing::TestWithParam>; + +TEST_P(LongIntegerInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected. + EXPECT_TRUE(succeeded); + if (inst != nullptr) { + EXPECT_EQ(inst->opcode(), spv::Op::OpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + EXPECT_EQ(inst->opcode(), spv::Op::OpConstant); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::Constant* constant = const_mrg->GetConstantFromInst(inst); + // We expect to see either integer types or 16-bit float types here. + ASSERT_NE(constant->AsIntConstant(), nullptr); + ASSERT_EQ(constant->type()->AsInteger()->width(), 64); + + const analysis::ScalarConstant* result = + const_mrg->GetConstantFromInst(inst)->AsScalarConstant(); + ASSERT_NE(result, nullptr); + if (result != nullptr) { + EXPECT_EQ(result->GetU64BitValue(), tc.expected_result); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + TestCase, LongIntegerInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold 1+4611686018427387904 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpIAdd %long %long_1 %long_4611686018427387904\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 1 + 4611686018427387904), + // Test case 1: fold 1-4611686018427387904 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpISub %long %long_1 %long_4611686018427387904\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 1 - 4611686018427387904), + // Test case 2: fold 2*4611686018427387904 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpIMul %long %long_2 %long_4611686018427387904\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 2 * 4611686018427387904), + // Test case 3: fold 4611686018427387904/2 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpUDiv %long %long_4611686018427387904 %long_2\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 4611686018427387904 / 2), + // Test case 4: fold 4611686018427387904/2 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSDiv %long %long_4611686018427387904 %long_2\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 4611686018427387904 / 2), + // Test case 5: fold -4611686018427387904/2 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSDiv %long %long_n4611686018427387904 %long_2\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, -4611686018427387904 / 2), + // Test case 6: fold 4611686018427387904 mod 7 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpUMod %long %long_4611686018427387904 %long_7\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 4611686018427387904ull % 7ull), + // Test case 7: fold 7 mod 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSMod %long %long_7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 1ull), + // Test case 8: fold 7 rem 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %long %long_7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 1ull), + // Test case 9: fold 7 mod -3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSMod %long %long_7 %long_n3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, -2ll), + // Test case 10: fold 7 rem 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %long %long_7 %long_n3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 1ll), + // Test case 11: fold -7 mod 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSMod %long %long_n7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 2ll), + // Test case 12: fold -7 rem 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %long %long_n7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, -1ll))); + using UIntVectorInstructionFoldingTest = ::testing::TestWithParam>>; @@ -3832,23 +4012,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 38: Don't fold 2 + 3 (long), bad length - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%2 = OpIAdd %long %long_2 %long_3\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, 0), - // Test case 39: Don't fold 2 + 3 (short), bad length - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%2 = OpIAdd %short %short_2 %short_3\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, 0), - // Test case 40: fold 1*n + // Test case 38: fold 1*n InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -3858,7 +4022,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 3), - // Test case 41: fold n*1 + // Test case 39: fold n*1 InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -3868,7 +4032,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 3), - // Test case 42: Don't fold comparisons of 64-bit types + // Test case 40: Don't fold comparisons of 64-bit types // (https://github.com/KhronosGroup/SPIRV-Tools/issues/3343). InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" +