Skip to content

Commit

Permalink
Fold 64-bit int operations.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Feb 6, 2024
1 parent bc10749 commit 691efa6
Show file tree
Hide file tree
Showing 2 changed files with 335 additions and 21 deletions.
152 changes: 151 additions & 1 deletion source/opt/const_folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,6 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
return [scalar_rule](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {

analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
Expand Down Expand Up @@ -716,6 +715,63 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
};
}

// Returns a |ConstantFoldingRule| that folds binary scalar ops
// using |scalar_rule| and binary vector ops by applying
// |scalar_rule| component-wise. The returned |ConstantFoldingRule|
// assumes that |constants| contains 2 entries. If they are
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
// whose element type is |Float| or |Integer|.
ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
  return [scalar_rule](IRContext* context, Instruction* inst,
                       const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());

    // For OpExtInst, operand 0 is the extended-instruction-set operand, so
    // the two value operands start at index 1.
    const bool is_ext_inst = (inst->opcode() == spv::Op::OpExtInst);
    const analysis::Constant* lhs = is_ext_inst ? constants[1] : constants[0];
    const analysis::Constant* rhs = is_ext_inst ? constants[2] : constants[1];
    if (lhs == nullptr || rhs == nullptr) {
      return nullptr;
    }

    const analysis::Vector* vector_type = result_type->AsVector();
    if (vector_type == nullptr) {
      // Scalar case: apply the rule directly.
      return scalar_rule(result_type, lhs, rhs, const_mgr);
    }

    // Vector case: fold each component pair.  Fold everything before
    // materializing any defining instructions, so a failed component leaves
    // the module untouched.
    std::vector<const analysis::Constant*> lhs_components =
        lhs->GetVectorComponents(const_mgr);
    std::vector<const analysis::Constant*> rhs_components =
        rhs->GetVectorComponents(const_mgr);
    std::vector<const analysis::Constant*> folded_components;
    for (uint32_t i = 0; i < lhs_components.size(); ++i) {
      const analysis::Constant* folded =
          scalar_rule(vector_type->element_type(), lhs_components[i],
                      rhs_components[i], const_mgr);
      if (folded == nullptr) {
        return nullptr;
      }
      folded_components.push_back(folded);
    }

    // Build the vector constant from the ids of the folded components.
    std::vector<uint32_t> ids;
    for (const analysis::Constant* member : folded_components) {
      ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
    }
    return const_mgr->GetConstant(vector_type, ids);
  };
}

// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
// using |scalar_rule| and unary floating point vector ops by applying
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
Expand Down Expand Up @@ -1587,6 +1643,60 @@ BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
return nullptr;
};
}

// Returns a |BinaryScalarFoldingRule| that evaluates |op| on the two integer
// operands widened to 64 bits, then narrows the 64-bit result back to the
// width of the result type so the stored constant words have a canonical
// encoding.
//
// |IsSigned| selects how the operands are widened to 64 bits: sign-extended
// when true, zero-extended when false.  It is a template parameter (rather
// than being derived from the operand types) because in SPIR-V the signedness
// of an operation comes from the opcode, not from the operand types.
template <bool IsSigned>
BinaryScalarFoldingRule FoldBinaryUnsignedIntegerOperation(
    uint64_t (*op)(uint64_t, uint64_t)) {
  return
      [op](const analysis::Type* result_type, const analysis::Constant* a,
           const analysis::Constant* b,
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
        assert(result_type != nullptr && a != nullptr);
        // Both operands and the result must share one integer type.
        const analysis::Integer* integer_type = a->type()->AsInteger();
        assert(integer_type != nullptr);
        assert(integer_type == result_type->AsInteger());
        assert(integer_type == b->type()->AsInteger());

        // In SPIR-V, the signedness of the operands is determined by the
        // opcode, and not by the type of the operands. This is why we use the
        // template argument to determine how to interpret the operands.
        uint64_t ia =
            (IsSigned ? a->GetSignExtendedValue() : a->GetZeroExtendedValue());
        uint64_t ib =
            (IsSigned ? b->GetSignExtendedValue() : b->GetZeroExtendedValue());
        uint64_t result = op(ia, ib);

        std::vector<uint32_t> words;
        if (integer_type->width() == 64) {
          // In the 64-bit case, two words are needed to represent the value:
          // low word first, then high word.
          words = {static_cast<uint32_t>(result),
                   static_cast<uint32_t>(result >> 32)};
        } else {
          // In all other cases, only a single word is needed.
          assert(integer_type->width() <= 32);

          // Narrow the 64-bit |result| to the result type's width, extending
          // the top bit through the unused upper bits as the type requires.
          uint64_t mask_for_last_bit = 1ull << (integer_type->width() - 1);
          uint64_t significant_bits = (mask_for_last_bit << 1) - 1ull;
          if (integer_type->IsSigned()) {
            // Sign-extend the most significant bit.
            if (result & mask_for_last_bit) {
              // Set upper bits to 1
              result |= ~significant_bits;
            } else {
              // Clear the upper bits
              result &= significant_bits;
            }
          } else {
            // Mask away any bits above the width of the result type.
            result &= (1ull << integer_type->width()) - 1ull;
          }

          words = {static_cast<uint32_t>(result)};
        }
        return const_mgr->GetConstant(result_type, words);
      };
}

} // namespace

void ConstantFoldingRules::AddFoldingRules() {
Expand Down Expand Up @@ -1662,6 +1772,46 @@ void ConstantFoldingRules::AddFoldingRules() {
rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());

rules_[spv::Op::OpIAdd].push_back(
FoldBinaryOp(FoldBinaryUnsignedIntegerOperation<false>(
[](uint64_t a, uint64_t b) { return a + b; })));
rules_[spv::Op::OpISub].push_back(
FoldBinaryOp(FoldBinaryUnsignedIntegerOperation<false>(
[](uint64_t a, uint64_t b) { return a - b; })));
rules_[spv::Op::OpIMul].push_back(
FoldBinaryOp(FoldBinaryUnsignedIntegerOperation<false>(
[](uint64_t a, uint64_t b) { return a * b; })));
rules_[spv::Op::OpUDiv].push_back(
FoldBinaryOp(FoldBinaryUnsignedIntegerOperation<false>(
[](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
FoldBinaryUnsignedIntegerOperation<true>([](uint64_t a, uint64_t b) {
return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
static_cast<int64_t>(b))
: 0);
})));
rules_[spv::Op::OpUMod].push_back(
FoldBinaryOp(FoldBinaryUnsignedIntegerOperation<false>(
[](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));

rules_[spv::Op::OpSRem].push_back(FoldBinaryOp(
FoldBinaryUnsignedIntegerOperation<true>([](uint64_t a, uint64_t b) {
return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) %
static_cast<int64_t>(b))
: 0);
})));

rules_[spv::Op::OpSMod].push_back(FoldBinaryOp(
FoldBinaryUnsignedIntegerOperation<true>([](uint64_t a, uint64_t b) {
if (b == 0) return static_cast<uint64_t>(0ull);

int64_t signed_a = static_cast<int64_t>(a);
int64_t signed_b = static_cast<int64_t>(b);
int64_t result = signed_a % signed_b;
if ((signed_b < 0) != (result < 0)) result += signed_b;
return static_cast<uint64_t>(result);
})));

// Add rules for GLSLstd450
FeatureManager* feature_manager = context_->get_feature_mgr();
uint32_t ext_inst_glslstd450_id =
Expand Down
Loading

0 comments on commit 691efa6

Please sign in to comment.