diff --git a/ir/attrs.cpp b/ir/attrs.cpp index ee1cdd5e7..d9e1184a8 100644 --- a/ir/attrs.cpp +++ b/ir/attrs.cpp @@ -148,6 +148,10 @@ ostream& operator<<(ostream &os, const FnAttrs &attr) { os << ", " << attr.allocsize_1; os << ')'; } + if (attr.vscaleRange) { + auto [low, high] = *attr.vscaleRange; + os << " vscale_range(" << low << ", " << high << ')'; + } attr.fp_denormal.print(os); if (attr.fp_denormal32) diff --git a/ir/attrs.h b/ir/attrs.h index dbb56826b..f451cfd84 100644 --- a/ir/attrs.h +++ b/ir/attrs.h @@ -137,6 +137,8 @@ class FnAttrs final { AllocSize = 1 << 12, ZeroExt = 1<<13, SignExt = 1<<14, NoFPClass = 1<<15, Asm = 1<<16 }; + std::optional> vscaleRange; + FnAttrs(unsigned bits = None) : bits(bits) {} bool has(Attribute a) const { return (bits & a) != 0; } diff --git a/ir/constant.cpp b/ir/constant.cpp index 5a3dee67c..f72698384 100644 --- a/ir/constant.cpp +++ b/ir/constant.cpp @@ -35,12 +35,12 @@ StateValue IntConst::toSMT(State &s) const { return { expr::mkInt(get(val).c_str(), bits()), true }; } -expr IntConst::getTypeConstraints() const { +expr IntConst::getTypeConstraints(const Function &f) const { unsigned min_bits = 0; if (auto v = get_if(&val)) min_bits = (*v >= 0 ? 63 : 64) - num_sign_bits(*v); - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntType() && getType().sizeVar().uge(min_bits); } @@ -86,8 +86,8 @@ FloatConst::FloatConst(Type &type, string val, bool bit_value) : Constant(type, bit_value ? int_to_readable_float(type, val) : val), val(std::move(val)), bit_value(bit_value) {} -expr FloatConst::getTypeConstraints() const { - return Value::getTypeConstraints() && +expr FloatConst::getTypeConstraints(const Function &f) const { + return Value::getTypeConstraints(f) && getType().enforceFloatType(); } @@ -104,12 +104,12 @@ StateValue FloatConst::toSMT(State &s) const { StateValue ConstantInput::toSMT(State &s) const { - auto type = getType().getDummyValue(false).value; + auto type = getType().getDummyValue(false, s.getVscale()).value; return { expr::mkVar(getName().c_str(), type), true }; } -expr ConstantInput::getTypeConstraints() const { - return Value::getTypeConstraints() && +expr ConstantInput::getTypeConstraints(const Function &f) const { + return Value::getTypeConstraints(f) && (getType().enforceIntType() || getType().enforceFloatType()); } @@ -157,8 +157,8 @@ StateValue ConstantBinOp::toSMT(State &s) const { return { std::move(val), ap && bp }; } -expr ConstantBinOp::getTypeConstraints() const { - return Value::getTypeConstraints() && +expr ConstantBinOp::getTypeConstraints(const Function &f) const { + return Value::getTypeConstraints(f) && getType().enforceIntType() && getType() == lhs.getType() && getType() == rhs.getType(); @@ -210,10 +210,10 @@ StateValue ConstantFn::toSMT(State &s) const { return { std::move(r), true }; } -expr ConstantFn::getTypeConstraints() const { - expr r = Value::getTypeConstraints(); +expr ConstantFn::getTypeConstraints(const Function &f) const { + expr r = Value::getTypeConstraints(f); for (auto a : args) { - r &= a->getTypeConstraints(); + r &= a->getTypeConstraints(f); } Type &ty = getType(); diff --git a/ir/constant.h b/ir/constant.h index b7a2f014c..94e562fd7 100644 --- a/ir/constant.h +++ b/ir/constant.h @@ -26,7 +26,7 @@ class IntConst final : public Constant { IntConst(Type &type, int64_t val); IntConst(Type &type, std::string &&val); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; auto getInt() const { return std::get_if(&val); } }; @@ -38,7 +38,7 @@ class FloatConst final : public Constant { FloatConst(Type &type, std::string val, bool bit_value); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; @@ -47,7 +47,7 @@ class ConstantInput final : public Constant { ConstantInput(Type &type, std::string &&name) : Constant(type, std::move(name)) {} StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; @@ -62,7 +62,7 @@ class ConstantBinOp final : public Constant { public: ConstantBinOp(Type &type, Constant &lhs, Constant &rhs, Op op); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; @@ -73,7 +73,7 @@ class ConstantFn final : public Constant { public: ConstantFn(Type &type, std::string_view name, std::vector &&args); StateValue toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; }; struct ConstantFnException { diff --git a/ir/function.cpp b/ir/function.cpp index b5a55990f..a6402a75b 100644 --- a/ir/function.cpp +++ b/ir/function.cpp @@ -165,7 +165,7 @@ expr Function::getTypeConstraints() const { } for (auto &l : { getConstants(), getInputs(), getUndefs() }) { for (auto &v : l) { - t &= v.getTypeConstraints(); + t &= v.getTypeConstraints(*this); } } return t; @@ -448,7 +448,7 @@ static vector top_sort(const vector &bbs) { // in order to account for some transitive dependencies we may have // missed due to compression of its inner loops. // If there are no inner loops, this is redundant and if `bb` is not - // a loop header, the set of its exit blocks is empty. + // a loop header, the set of its exit blocks is empty. for (auto &dst : bb->getExitBlocks()) { auto dst_I = bb_map.find(dst); if (dst_I != bb_map.end()) @@ -807,8 +807,8 @@ void Function::unroll(unsigned k) { static PtrType ptr_type(0); static IntType i32(string("i32"), 32); auto &type = val->getType(); - auto size_alloc - = make_unique(i32, Memory::getStoreByteSize(type)); + auto size_alloc = make_unique( + i32, Memory::getStoreByteSize(type, expr::mkVscaleMin())); auto *size = size_alloc.get(); addConstant(std::move(size_alloc)); @@ -1017,7 +1017,7 @@ void DomTree::buildDominators(const CFG &cfg) { auto &entry = doms.at(&f.getFirstBB()); entry.dominator = &entry; - // Cooper, Keith D.; Harvey, Timothy J.; and Kennedy, Ken (2001). + // Cooper, Keith D.; Harvey, Timothy J.; and Kennedy, Ken (2001). // A Simple, Fast Dominance Algorithm // http://www.cs.rice.edu/~keith/EMBED/dom.pdf // Makes multiple passes when CFG is cyclic to update incorrect initial @@ -1220,5 +1220,5 @@ void LoopAnalysis::printDot(ostream &os) const { os << "}\n"; } -} +} diff --git a/ir/instr.cpp b/ir/instr.cpp index 2c9b97aa0..f1b40d0ea 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -102,7 +102,7 @@ uint64_t getGlobalVarSize(const IR::Value *V) { namespace IR { -expr Instr::getTypeConstraints() const { +expr Instr::getTypeConstraints(const Function &f) const { UNREACHABLE(); return {}; } @@ -469,7 +469,7 @@ StateValue BinOp::toSMT(State &s) const { auto &ty = getType(); uint32_t resBits = (ty.isVectorType() ? ty.getAsAggregateType()->getChild(0) : ty) - .bits(); + .bits(s.getVscale()); return {expr::mkIf(a == b, expr::mkUInt(0, resBits), expr::mkIf(op == UCmp ? a.ult(b) : a.slt(b), expr::mkInt(-1, resBits), @@ -510,21 +510,23 @@ StateValue BinOp::toSMT(State &s) const { auto val1ty = retty->getChild(0).getAsAggregateType(); auto val2ty = retty->getChild(val2idx).getAsAggregateType(); - for (unsigned i = 0, e = ty->numElementsConst(); i != e; ++i) { - auto ai = ty->extract(a, i); - auto bi = ty->extract(b, i); + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; + ++i) { + auto ai = ty->extract(a, i, s.getVscale()); + auto bi = ty->extract(b, i, s.getVscale()); auto [v1, v2] = zip_op(ai.value, ai.non_poison, bi.value, bi.non_poison); vals1.emplace_back(std::move(v1)); vals2.emplace_back(std::move(v2)); } - vals.emplace_back(val1ty->aggregateVals(vals1)); - vals.emplace_back(val2ty->aggregateVals(vals2)); + vals.emplace_back(val1ty->aggregateVals(vals1, s.getVscale())); + vals.emplace_back(val2ty->aggregateVals(vals2, s.getVscale())); } else { StateValue tmp; auto opty = lhs->getType().getAsAggregateType(); - for (unsigned i = 0, e = opty->numElementsConst(); i != e; ++i) { - auto ai = opty->extract(a, i); + for (unsigned i = 0, e = opty->numElementsConst(s.getVscale()); i != e; + ++i) { + auto ai = opty->extract(a, i, s.getVscale()); const StateValue *bi; switch (op) { case Abs: @@ -533,7 +535,7 @@ StateValue BinOp::toSMT(State &s) const { bi = &b; break; default: - tmp = opty->extract(b, i); + tmp = opty->extract(b, i, s.getVscale()); bi = &tmp; break; } @@ -541,7 +543,7 @@ StateValue BinOp::toSMT(State &s) const { bi->non_poison)); } } - return retty->aggregateVals(vals); + return retty->aggregateVals(vals, s.getVscale()); } if (vertical_zip) { @@ -549,7 +551,7 @@ StateValue BinOp::toSMT(State &s) const { auto [v1, v2] = zip_op(a.value, a.non_poison, b.value, b.non_poison); vals.emplace_back(std::move(v1)); vals.emplace_back(std::move(v2)); - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } return scalar_op(a.value, a.non_poison, b.value, b.non_poison); } @@ -596,7 +598,7 @@ expr BinOp::getTypeConstraints(const Function &f) const { getType() == rhs->getType(); break; } - return Value::getTypeConstraints() && std::move(instrconstr); + return Value::getTypeConstraints(f) && std::move(instrconstr); } unique_ptr BinOp::dup(Function &f, const string &suffix) const { @@ -948,17 +950,19 @@ StateValue FpBinOp::toSMT(State &s) const { if (lhs->getType().isVectorType()) { auto retty = getType().getAsAggregateType(); vector vals; - for (unsigned i = 0, e = retty->numElementsConst(); i != e; ++i) { - vals.emplace_back(scalar(retty->extract(a, i), retty->extract(b, i), + for (unsigned i = 0, e = retty->numElementsConst(s.getVscale()); i != e; + ++i) { + vals.emplace_back(scalar(retty->extract(a, i, s.getVscale()), + retty->extract(b, i, s.getVscale()), retty->getChild(i))); } - return retty->aggregateVals(vals); + return retty->aggregateVals(vals, s.getVscale()); } return scalar(a, b, getType()); } expr FpBinOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceFloatOrVectorType() && getType() == lhs->getType() && getType() == rhs->getType(); @@ -1056,11 +1060,11 @@ StateValue UnaryOp::toSMT(State &s) const { if (getType().isVectorType()) { vector vals; auto ty = val->getType().getAsAggregateType(); - for (unsigned i = 0, e = ty->numElementsConst(); i != e; ++i) { - auto vi = ty->extract(v, i); + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; ++i) { + auto vi = ty->extract(v, i, s.getVscale()); vals.emplace_back(fn(vi.value, vi.non_poison)); } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } return fn(v.value, v.non_poison); } @@ -1086,7 +1090,7 @@ expr UnaryOp::getTypeConstraints(const Function &f) const { break; } - return Value::getTypeConstraints() && std::move(instrconstr); + return Value::getTypeConstraints(f) && std::move(instrconstr); } static Value* dup_aggregate(Function &f, Value *val) { @@ -1095,8 +1099,8 @@ static Value* dup_aggregate(Function &f, Value *val) { for (auto v : agg->getVals()) { elems.emplace_back(dup_aggregate(f, v)); } - auto agg_new - = make_unique(agg->getType(), std::move(elems)); + auto agg_new = + make_unique(agg->getType(), std::move(elems)); auto ret = agg_new.get(); f.addAggregate(std::move(agg_new)); return ret; @@ -1204,16 +1208,17 @@ StateValue FpUnaryOp::toSMT(State &s) const { if (getType().isVectorType()) { vector vals; auto ty = val->getType().getAsAggregateType(); - for (unsigned i = 0, e = ty->numElementsConst(); i != e; ++i) { - vals.emplace_back(scalar(ty->extract(v, i), ty->getChild(i))); + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; ++i) { + vals.emplace_back( + scalar(ty->extract(v, i, s.getVscale()), ty->getChild(i))); } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } return scalar(v, getType()); } expr FpUnaryOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == val->getType() && getType().enforceFloatOrVectorType(); } @@ -1262,8 +1267,8 @@ StateValue UnaryReductionOp::toSMT(State &s) const { auto &v = s[*val]; auto vty = val->getType().getAsAggregateType(); StateValue res; - for (unsigned i = 0, e = vty->numElementsConst(); i != e; ++i) { - auto ith = vty->extract(v, i); + for (unsigned i = 0, e = vty->numElementsConst(s.getVscale()); i != e; ++i) { + auto ith = vty->extract(v, i, s.getVscale()); if (i == 0) { res = std::move(ith); continue; @@ -1286,7 +1291,7 @@ StateValue UnaryReductionOp::toSMT(State &s) const { } expr UnaryReductionOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntType() && val->getType().enforceVectorType( [this](auto &scalar) { return scalar == getType(); }); @@ -1370,12 +1375,12 @@ StateValue TernaryOp::toSMT(State &s) const { vector vals; auto ty = getType().getAsAggregateType(); - for (unsigned i = 0, e = ty->numElementsConst(); i != e; ++i) { - vals.emplace_back(scalar(ty->extract(av, i), ty->extract(bv, i), - (op == FShl || op == FShr) ? - ty->extract(cv, i) : cv)); + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; ++i) { + vals.emplace_back(scalar( + ty->extract(av, i, s.getVscale()), ty->extract(bv, i, s.getVscale()), + (op == FShl || op == FShr) ? ty->extract(cv, i, s.getVscale()) : cv)); } - return ty->aggregateVals(vals); + return ty->aggregateVals(vals, s.getVscale()); } return scalar(av, bv, cv); } @@ -1405,7 +1410,7 @@ expr TernaryOp::getTypeConstraints(const Function &f) const { getType().enforceIntOrVectorType(); break; } - return Value::getTypeConstraints() && instrconstr; + return Value::getTypeConstraints(f) && instrconstr; } unique_ptr TernaryOp::dup(Function &f, const string &suffix) const { @@ -1476,17 +1481,18 @@ StateValue FpTernaryOp::toSMT(State &s) const { vector vals; auto ty = getType().getAsAggregateType(); - for (unsigned i = 0, e = ty->numElementsConst(); i != e; ++i) { - vals.emplace_back(scalar(ty->extract(av, i), ty->extract(bv, i), - ty->extract(cv, i), ty->getChild(i))); + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; ++i) { + vals.emplace_back(scalar( + ty->extract(av, i, s.getVscale()), ty->extract(bv, i, s.getVscale()), + ty->extract(cv, i, s.getVscale()), ty->getChild(i))); } - return ty->aggregateVals(vals); + return ty->aggregateVals(vals, s.getVscale()); } return scalar(av, bv, cv, getType()); } expr FpTernaryOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == a->getType() && getType() == b->getType() && getType() == c->getType() && @@ -1548,16 +1554,17 @@ StateValue TestOp::toSMT(State &s) const { vector vals; auto ty = lhs->getType().getAsAggregateType(); - for (unsigned i = 0, e = ty->numElementsConst(); i != e; ++i) { - vals.emplace_back(scalar(ty->extract(a, i), ty->getChild(i))); + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; ++i) { + vals.emplace_back( + scalar(ty->extract(a, i, s.getVscale()), ty->getChild(i))); } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } return scalar(a, lhs->getType()); } expr TestOp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && lhs->getType().enforceFloatOrVectorType() && rhs->getType().enforceIntType(32) && getType().enforceIntOrVectorType(1) && @@ -1627,21 +1634,21 @@ StateValue ConversionOp::toSMT(State &s) const { switch (op) { case SExt: - fn = [](auto &&val, auto &to_type) -> StateValue { - return {val.sext(to_type.bits() - val.bits()), true}; + fn = [&](auto &&val, auto &to_type) -> StateValue { + return {val.sext(to_type.bits(s.getVscale()) - val.bits()), true}; }; break; case ZExt: fn = [&](auto &&val, auto &to_type) -> StateValue { - return { val.zext(to_type.bits() - val.bits()), + return { val.zext(to_type.bits(s.getVscale()) - val.bits()), (flags & NNEG) ? !val.isNegative() : true }; }; break; case Trunc: - fn = [this](auto &&val, auto &to_type) -> StateValue { + fn = [&](auto &&val, auto &to_type) -> StateValue { AndExpr non_poison; unsigned orig_bits = val.bits(); - unsigned trunc_bits = to_type.bits(); + unsigned trunc_bits = to_type.bits(s.getVscale()); expr val_truncated = val.trunc(trunc_bits); if (flags & NUW) non_poison.add(val.extract(orig_bits-1, trunc_bits) == 0); @@ -1656,11 +1663,14 @@ StateValue ConversionOp::toSMT(State &s) const { getType().getAsAggregateType()->getChild(0).isPtrType()) return v; - return getType().fromInt(val->getType().toInt(s, std::move(v))); + return getType().fromInt(val->getType().toInt(s, std::move(v)), + s.getVscale()); case Ptr2Int: fn = [&](auto &&val, auto &to_type) -> StateValue { - return {s.getMemory().ptr2int(val).zextOrTrunc(to_type.bits()), true}; + return { + s.getMemory().ptr2int(val).zextOrTrunc(to_type.bits(s.getVscale())), + true}; }; break; case Int2Ptr: @@ -1678,13 +1688,14 @@ StateValue ConversionOp::toSMT(State &s) const { if (getType().isVectorType()) { vector vals; auto retty = getType().getAsAggregateType(); - auto elems = retty->numElementsConst(); + auto elems = retty->numElementsConst(s.getVscale()); auto valty = val->getType().getAsAggregateType(); for (unsigned i = 0; i != elems; ++i) { - vals.emplace_back(scalar(valty->extract(v, i), retty->getChild(i))); + vals.emplace_back( + scalar(valty->extract(v, i, s.getVscale()), retty->getChild(i))); } - return retty->aggregateVals(vals); + return retty->aggregateVals(vals, s.getVscale()); } return scalar(std::move(v), getType()); @@ -1721,7 +1732,7 @@ expr ConversionOp::getTypeConstraints(const Function &f) const { break; } - c &= Value::getTypeConstraints(); + c &= Value::getTypeConstraints(f); if (op != BitCast) c &= getType().enforceVectorTypeEquiv(val->getType()); return c; @@ -1837,7 +1848,7 @@ StateValue FpConversionOp::toSMT(State &s) const { break; default: UNREACHABLE(); } - auto bits = to_type.bits(); + auto bits = to_type.bits(s.getVscale()); expr bv = val.fp2sint(bits, rm); expr fp2 = bv.sint2fp(val, rm); // -0.xx is converted to 0 and then to 0.0, though -0.xx is ok to convert @@ -1867,7 +1878,7 @@ StateValue FpConversionOp::toSMT(State &s) const { case FPToUInt: case FPToUInt_Sat: fn = [&](auto &val, auto &to_type, auto &rm_) -> StateValue { - auto bits = to_type.bits(); + auto bits = to_type.bits(s.getVscale()); expr rm = expr::rtz(); expr bv = val.fp2uint(bits, rm); expr fp2 = bv.uint2fp(val, rm); @@ -1928,11 +1939,11 @@ StateValue FpConversionOp::toSMT(State &s) const { auto ty = val->getType().getAsAggregateType(); auto retty = getType().getAsAggregateType(); - for (unsigned i = 0, e = ty->numElementsConst(); i != e; ++i) { - vals.emplace_back(scalar(ty->extract(v, i), ty->getChild(i), - retty->getChild(i))); + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; ++i) { + vals.emplace_back(scalar(ty->extract(v, i, s.getVscale()), + ty->getChild(i), retty->getChild(i))); } - return retty->aggregateVals(vals); + return retty->aggregateVals(vals, s.getVscale()); } return scalar(v, val->getType(), getType()); } @@ -1965,7 +1976,7 @@ expr FpConversionOp::getTypeConstraints(const Function &f) const { val->getType().scalarSize().ugt(getType().scalarSize()); break; } - return Value::getTypeConstraints() && c; + return Value::getTypeConstraints(f) && c; } unique_ptr FpConversionOp::dup(Function &f, const string &suffix) const { @@ -2015,19 +2026,22 @@ StateValue Select::toSMT(State &s) const { vector vals; auto cond_agg = cond->getType().getAsAggregateType(); - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { if (!agg->isPadding(i)) - vals.emplace_back(scalar(agg->extract(av, i), agg->extract(bv, i), - cond_agg ? cond_agg->extract(cv, i) : cv, - agg->getChild(i))); + vals.emplace_back( + scalar(agg->extract(av, i, s.getVscale()), + agg->extract(bv, i, s.getVscale()), + cond_agg ? cond_agg->extract(cv, i, s.getVscale()) : cv, + agg->getChild(i))); } - return agg->aggregateVals(vals); + return agg->aggregateVals(vals, s.getVscale()); } return scalar(av, bv, cv, getType()); } expr Select::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && cond->getType().enforceIntOrVectorType(1) && getType().enforceVectorTypeIff(cond->getType()) && (fmath.isNone() ? expr(true) : getType().enforceFloatOrVectorType()) && @@ -2073,15 +2087,16 @@ StateValue ExtractValue::toSMT(State &s) const { Type *type = &val->getType(); for (auto idx : idxs) { auto ty = type->getAsAggregateType(); - v = ty->extract(v, idx); + v = ty->extract(v, idx, s.getVscale()); type = &ty->getChild(idx); } return v; } expr ExtractValue::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints() && + auto c = Value::getTypeConstraints(f) && val->getType().enforceAggregateType(); + expr vscaleRange = State::vscaleFromAttr(f.getFnAttrs().vscaleRange); Type *type = &val->getType(); unsigned i = 0; @@ -2093,8 +2108,8 @@ expr ExtractValue::getTypeConstraints(const Function &f) const { } type = &ty->getChild(idx); - c &= ty->numElements().ugt(idx); - if (++i == idxs.size() && idx < ty->numElementsConst()) + c &= ty->numElements(vscaleRange).ugt(idx); + if (++i == idxs.size() && idx < ty->numElementsConst(vscaleRange)) c &= ty->getChild(idx) == getType(); } return c; @@ -2137,7 +2152,8 @@ void InsertValue::print(ostream &os) const { } } -static StateValue update_repack(Type *type, +static StateValue update_repack(const State &s, + Type *type, const StateValue &val, const StateValue &elem, vector &indices) { @@ -2145,21 +2161,21 @@ static StateValue update_repack(Type *type, unsigned cur_idx = indices.back(); indices.pop_back(); vector vals; - for (unsigned i = 0, e = ty->numElementsConst(); i < e; ++i) { + for (unsigned i = 0, e = ty->numElementsConst(s.getVscale()); i != e; ++i) { if (ty->isPadding(i)) continue; - auto v = ty->extract(val, i); + auto v = ty->extract(val, i, s.getVscale()); if (i == cur_idx) { - vals.emplace_back(indices.empty() ? - elem : - update_repack(&ty->getChild(i), v, elem, indices)); + vals.emplace_back(indices.empty() ? elem + : update_repack(s, &ty->getChild(i), v, + elem, indices)); } else { vals.emplace_back(std::move(v)); } } - return ty->aggregateVals(vals); + return ty->aggregateVals(vals, s.getVscale()); } StateValue InsertValue::toSMT(State &s) const { @@ -2168,13 +2184,14 @@ StateValue InsertValue::toSMT(State &s) const { Type *type = &val->getType(); vector idxs_reverse(idxs.rbegin(), idxs.rend()); - return update_repack(type, sv, elem, idxs_reverse); + return update_repack(s, type, sv, elem, idxs_reverse); } expr InsertValue::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints() && + auto c = Value::getTypeConstraints(f) && val->getType().enforceAggregateType() && val->getType() == getType(); + expr vscaleRange = State::vscaleFromAttr(f.getFnAttrs().vscaleRange); Type *type = &val->getType(); unsigned i = 0; @@ -2185,8 +2202,8 @@ expr InsertValue::getTypeConstraints(const Function &f) const { type = &ty->getChild(idx); - c &= ty->numElements().ugt(idx); - if (++i == idxs.size() && idx < ty->numElementsConst()) + c &= ty->numElements(vscaleRange).ugt(idx); + if (++i == idxs.size() && idx < ty->numElementsConst(vscaleRange)) c &= ty->getChild(idx) == elt->getType(); } @@ -2421,11 +2438,14 @@ static void unpack_inputs(State &s, Value &argv, Type &ty, StateValue value2, vector &inputs, vector &ptr_inputs, unsigned idx) { if (auto agg = ty.getAsAggregateType()) { - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { if (agg->isPadding(i)) continue; - unpack_inputs(s, argv, agg->getChild(i), argflag, agg->extract(value, i), - agg->extract(value2, i), inputs, ptr_inputs, idx); + unpack_inputs(s, argv, agg->getChild(i), argflag, + agg->extract(value, i, s.getVscale()), + agg->extract(value2, i, s.getVscale()), inputs, ptr_inputs, + idx); } return; } @@ -2462,12 +2482,14 @@ pack_return(State &s, Type &ty, StateValue &&val, const FnAttrs &attrs, const vector> &args) { if (auto agg = ty.getAsAggregateType()) { vector vs; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { if (!agg->isPadding(i)) - vs.emplace_back( - pack_return(s, agg->getChild(i), agg->extract(val, i), attrs, args)); + vs.emplace_back(pack_return(s, agg->getChild(i), + agg->extract(val, i, s.getVscale()), attrs, + args)); } - return agg->aggregateVals(vs); + return agg->aggregateVals(vs, s.getVscale()); } return check_return_value(s, std::move(val), ty, attrs, args); @@ -2646,7 +2668,7 @@ StateValue FnCall::toSMT(State &s) const { expr FnCall::getTypeConstraints(const Function &f) const { // TODO : also need to name each arg type smt var uniquely - expr ret = Value::getTypeConstraints(); + expr ret = Value::getTypeConstraints(f); if (fnptr) ret &= fnptr->getType().enforcePtrType(); return ret; @@ -2799,17 +2821,18 @@ StateValue ICmp::toSMT(State &s) const { auto &elem_ty = a->getType(); if (auto agg = elem_ty.getAsAggregateType()) { vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { - vals.emplace_back(scalar(agg->extract(a_eval, i), - agg->extract(b_eval, i))); + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { + vals.emplace_back(scalar(agg->extract(a_eval, i, s.getVscale()), + agg->extract(b_eval, i, s.getVscale()))); } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } return scalar(a_eval, b_eval); } expr ICmp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntOrVectorType(1) && getType().enforceVectorTypeEquiv(a->getType()) && a->getType().enforceIntOrPtrOrVectorType() && @@ -2898,17 +2921,19 @@ StateValue FCmp::toSMT(State &s) const { if (auto agg = a->getType().getAsAggregateType()) { vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { - vals.emplace_back(fn(agg->extract(a_eval, i), agg->extract(b_eval, i), + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { + vals.emplace_back(fn(agg->extract(a_eval, i, s.getVscale()), + agg->extract(b_eval, i, s.getVscale()), agg->getChild(i))); } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } return fn(a_eval, b_eval, a->getType()); } expr FCmp::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceIntOrVectorType(1) && getType().enforceVectorTypeEquiv(a->getType()) && a->getType().enforceFloatOrVectorType() && @@ -2943,18 +2968,20 @@ void Freeze::print(ostream &os) const { static StateValue freeze_elems(State &s, const Type &ty, const StateValue &v) { if (auto agg = ty.getAsAggregateType()) { vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { if (agg->isPadding(i)) continue; - vals.emplace_back(freeze_elems(s, agg->getChild(i), agg->extract(v, i))); + vals.emplace_back( + freeze_elems(s, agg->getChild(i), agg->extract(v, i, s.getVscale()))); } - return agg->aggregateVals(vals); + return agg->aggregateVals(vals, s.getVscale()); } if (v.non_poison.isTrue()) return v; - StateValue ret_type = ty.getDummyValue(true); + StateValue ret_type = ty.getDummyValue(true, s.getVscale()); expr nondet = expr::mkFreshVar("nondet", ret_type.value); s.addQuantVar(nondet); return { expr::mkIf(v.non_poison, v.value, nondet), @@ -2968,7 +2995,7 @@ StateValue Freeze::toSMT(State &s) const { } expr Freeze::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == val->getType(); } @@ -3060,7 +3087,7 @@ void Phi::print(ostream &os) const { } StateValue Phi::toSMT(State &s) const { - DisjointExpr ret(getType().getDummyValue(false)); + DisjointExpr ret(getType().getDummyValue(false, s.getVscale())); map cache; for (auto &[val, bb] : values) { @@ -3080,7 +3107,7 @@ StateValue Phi::toSMT(State &s) const { } expr Phi::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints(); + auto c = Value::getTypeConstraints(f); for (auto &[val, bb] : values) { c &= val->getType() == getType(); } @@ -3281,14 +3308,16 @@ check_ret_attributes(State &s, StateValue &&sv, const StateValue &returned_arg, const vector> &args) { if (auto agg = t.getAsAggregateType()) { vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { if (agg->isPadding(i)) continue; - vals.emplace_back(check_ret_attributes(s, agg->extract(sv, i), - agg->extract(returned_arg, i), - agg->getChild(i), attrs, args)); + vals.emplace_back( + check_ret_attributes(s, agg->extract(sv, i, s.getVscale()), + agg->extract(returned_arg, i, s.getVscale()), + agg->getChild(i), attrs, args)); } - return agg->aggregateVals(vals); + return agg->aggregateVals(vals, s.getVscale()); } if (t.isPtrType()) { @@ -3324,7 +3353,7 @@ StateValue Return::toSMT(State &s) const { } expr Return::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == val->getType() && f.getType() == getType(); } @@ -3565,11 +3594,12 @@ StateValue AssumeVal::toSMT(State &s) const { auto &v = s.getMaybeUB(*val, is_welldefined); if (auto agg = getType().getAsAggregateType()) { vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { - auto elem = agg->extract(v, i); + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { + auto elem = agg->extract(v, i, s.getVscale()); vals.emplace_back(expr(elem.value), elem.non_poison && fn(elem.value)); } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } expr np = fn(v.value); @@ -3711,7 +3741,7 @@ StateValue Alloc::toSMT(State &s) const { } expr Alloc::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforcePtrType() && size->getType().enforceIntType(); } @@ -3947,18 +3977,21 @@ StateValue GEP::toSMT(State &s) const { auto &ptrval = s[*ptr]; bool ptr_isvect = ptr->getType().isVectorType(); - for (unsigned i = 0, e = aty->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = aty->numElementsConst(s.getVscale()); i != e; + ++i) { vector> offsets; for (auto &[sz, idx] : idxs) { if (auto idx_aty = idx->getType().getAsAggregateType()) - offsets.emplace_back(sz, idx_aty->extract(s[*idx], i)); + offsets.emplace_back(sz, idx_aty->extract(s[*idx], i, s.getVscale())); else offsets.emplace_back(sz, s[*idx]); } - vals.emplace_back(scalar(ptr_isvect ? aty->extract(ptrval, i) : - (i == 0 ? ptrval : s[*ptr]), offsets)); + vals.emplace_back(scalar(ptr_isvect + ? aty->extract(ptrval, i, s.getVscale()) + : (i == 0 ? ptrval : s[*ptr]), + offsets)); } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } vector> offsets; for (auto &[sz, idx] : idxs) @@ -3967,7 +4000,7 @@ StateValue GEP::toSMT(State &s) const { } expr GEP::getTypeConstraints(const Function &f) const { - auto c = Value::getTypeConstraints() && + auto c = Value::getTypeConstraints(f) && getType().enforceVectorTypeIff(ptr->getType()) && getType().enforcePtrOrVectorType(); for (auto &[sz, idx] : idxs) { @@ -4042,17 +4075,18 @@ StateValue PtrMask::toSMT(State &s) const { auto maskTy = mask->getType().getAsAggregateType(); assert(maskTy); vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { - vals.emplace_back(fn(agg->extract(ptrval, i), - maskTy->extract(maskval, i))); + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i != e; + ++i) { + vals.emplace_back(fn(agg->extract(ptrval, i, s.getVscale()), + maskTy->extract(maskval, i, s.getVscale()))); } - return agg->aggregateVals(vals); + return agg->aggregateVals(vals, s.getVscale()); } return fn(ptrval, maskval); } expr PtrMask::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && ptr->getType().enforcePtrOrVectorType() && getType() == ptr->getType() && mask->getType().enforceIntOrVectorType() && @@ -4068,7 +4102,8 @@ DEFINE_AS_RETZEROALIGN(Load, getMaxAllocSize) DEFINE_AS_RETZERO(Load, getMaxGEPOffset) uint64_t Load::getMaxAccessSize() const { - return round_up(Memory::getStoreByteSize(getType()), align); + return round_up(Memory::getStoreByteSize(getType(), expr::mkVscaleMin()), + align); } MemInstr::ByteAccessInfo Load::getByteAccessInfo() const { @@ -4101,7 +4136,7 @@ StateValue Load::toSMT(State &s) const { } expr Load::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && ptr->getType().enforcePtrType(); } @@ -4114,7 +4149,8 @@ DEFINE_AS_RETZEROALIGN(Store, getMaxAllocSize) DEFINE_AS_RETZERO(Store, getMaxGEPOffset) uint64_t Store::getMaxAccessSize() const { - return round_up(Memory::getStoreByteSize(val->getType()), align); + return round_up(Memory::getStoreByteSize(val->getType(), expr::mkVscaleMin()), + align); } MemInstr::ByteAccessInfo Store::getByteAccessInfo() const { @@ -4142,7 +4178,9 @@ StateValue Store::toSMT(State &s) const { // skip large initializers. FIXME: this should be moved to memory so it can // fold subsequent trivial loads if (s.isInitializationPhase() && - Memory::getStoreByteSize(val->getType()) / (bits_byte / 8) > 128) { + Memory::getStoreByteSize(val->getType(), s.getVscale()) / + (bits_byte / 8) > + 128) { s.doesApproximation("Large constant initializer removed"); return {}; } @@ -4574,7 +4612,8 @@ StateValue Strlen::toSMT(State &s) const { ub.add(std::move(ub_load.first)); ub.add(std::move(ub_load.second)); ub.add(std::move(val.non_poison)); - return { expr::mkUInt(i, ty.bits()), true, std::move(ub), val.value != 0 }; + return {expr::mkUInt(i, ty.bits(s.getVscale())), true, std::move(ub), + val.value != 0}; }; auto [val, _, ub] = LoopLikeFunctionApproximator(ith_exec).encode(s, strlen_unroll_cnt); @@ -4583,7 +4622,7 @@ StateValue Strlen::toSMT(State &s) const { } expr Strlen::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && ptr->getType().enforcePtrType() && getType().enforceIntType(); } @@ -4817,7 +4856,7 @@ StateValue VaArg::toSMT(State &s) const { ensure_varargs_ptr(data, s, raw_p); DisjointExpr ret(StateValue{}); - expr value_kind = getType().getDummyValue(false).value; + expr value_kind = getType().getDummyValue(false, s.getVscale()).value; expr one = expr::mkUInt(1, VARARG_BITS); for (auto &[ptr, entry] : data) { @@ -4880,13 +4919,13 @@ void ExtractElement::print(ostream &os) const { StateValue ExtractElement::toSMT(State &s) const { auto &[iv, ip] = s[*idx]; auto vty = static_cast(v->getType().getAsAggregateType()); - expr inbounds = iv.ult(vty->numElementsConst()); - auto [rv, rp] = vty->extract(s[*v], iv); + expr inbounds = iv.ult(vty->numElementsConst(s.getVscale())); + auto [rv, rp] = vty->extract(s[*v], iv, s.getVscale()); return { std::move(rv), ip && inbounds && rp }; } expr ExtractElement::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && v->getType().enforceVectorType([&](auto &ty) { return ty == getType(); }) && idx->getType().enforceIntType(); @@ -4922,14 +4961,15 @@ void InsertElement::print(ostream &os) const { StateValue InsertElement::toSMT(State &s) const { auto &[iv, ip] = s[*idx]; auto vty = static_cast(v->getType().getAsAggregateType()); - expr inbounds = iv.ult(vty->numElementsConst()); - auto [rv, rp] = vty->update(s[*v], s[*e], iv); - return { std::move(rv), expr::mkIf(ip && inbounds, std::move(rp), - vty->getDummyValue(false).non_poison) }; + expr inbounds = iv.ult(vty->numElementsConst(s.getVscale())); + auto [rv, rp] = vty->update(s[*v], s[*e], iv, s.getVscale()); + return {std::move(rv), + expr::mkIf(ip && inbounds, std::move(rp), + vty->getDummyValue(false, s.getVscale()).non_poison)}; } expr InsertElement::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType() == v->getType() && v->getType().enforceVectorType([&](auto &ty) { return ty == e->getType(); }) && @@ -4968,23 +5008,23 @@ void ShuffleVector::print(ostream &os) const { StateValue ShuffleVector::toSMT(State &s) const { auto vty = v1->getType().getAsAggregateType(); - auto sz = vty->numElementsConst(); + auto sz = vty->numElementsConst(s.getVscale()); vector vals; for (auto m : mask) { if (m >= 2 * sz) { - vals.emplace_back(vty->getChild(0).getDummyValue(false)); + vals.emplace_back(vty->getChild(0).getDummyValue(false, s.getVscale())); } else { auto *vect = &s[m < sz ? *v1 : *v2]; - vals.emplace_back(vty->extract(*vect, m % sz)); + vals.emplace_back(vty->extract(*vect, m % sz, s.getVscale())); } } - return getType().getAsAggregateType()->aggregateVals(vals); + return getType().getAsAggregateType()->aggregateVals(vals, s.getVscale()); } expr ShuffleVector::getTypeConstraints(const Function &f) const { - return Value::getTypeConstraints() && + return Value::getTypeConstraints(f) && getType().enforceVectorTypeSameChildTy(v1->getType()) && getType().getAsAggregateType()->numElements() == mask.size() && v1->getType().enforceVectorType() && diff --git a/ir/instr.h b/ir/instr.h index 53efe0b75..675b4a6c4 100644 --- a/ir/instr.h +++ b/ir/instr.h @@ -23,8 +23,7 @@ class Instr : public Value { virtual bool propagatesPoison() const = 0; virtual bool hasSideEffects() const = 0; virtual bool isTerminator() const; - smt::expr getTypeConstraints() const override; - virtual smt::expr getTypeConstraints(const Function &f) const = 0; + smt::expr getTypeConstraints(const Function &f) const override; virtual std::unique_ptr dup(Function &f, const std::string &suffix) const = 0; }; @@ -377,7 +376,7 @@ class ExtractValue final : public Instr { public: ExtractValue(Type &type, std::string &&name, Value &val) : Instr(type, std::move(name)), val(&val) {} - + const auto& getIdxs() const { return idxs; } void addIdx(unsigned idx); @@ -399,7 +398,7 @@ class InsertValue final : public Instr { public: InsertValue(Type &type, std::string &&name, Value &val, Value &elt) : Instr(type, std::move(name)), val(&val), elt(&elt) {} - + const auto& getIdxs() const { return idxs; } void addIdx(unsigned idx); @@ -536,7 +535,7 @@ class Phi final : public Instr { void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; smt::expr getTypeConstraints(const Function &f) const override; - std::unique_ptr + std::unique_ptr dup(Function &f, const std::string &suffix) const override; }; @@ -1277,7 +1276,7 @@ class ShuffleVector final : public Instr { ShuffleVector(Type &type, std::string &&name, Value &v1, Value &v2, std::vector mask) : Instr(type, std::move(name)), v1(&v1), v2(&v2), mask(std::move(mask)) {} - + const auto& getMask() const { return mask; } std::vector operands() const override; bool propagatesPoison() const override; diff --git a/ir/memory.cpp b/ir/memory.cpp index 326e9012f..1373adaaf 100644 --- a/ir/memory.cpp +++ b/ir/memory.cpp @@ -589,7 +589,8 @@ static vector valueToBytes(const StateValue &val, const Type &fromType, for (unsigned i = 0; i < bytesize; ++i) bytes.emplace_back(mem, StateValue(expr(p()), expr(val.non_poison)), i); } else { - assert(!fromType.isAggregateType() || isNonPtrVector(fromType)); + assert(!fromType.isAggregateType() || + isNonPtrVector(fromType, s.getVscale())); StateValue bvval = fromType.toInt(s, val); unsigned bitsize = bvval.bits(); unsigned bytesize = divide_up(bitsize, bits_byte); @@ -674,7 +675,8 @@ static StateValue bytesToValue(const Memory &m, const vector &bytes, return { std::move(loaded_ptr), std::move(non_poison) }; } else { - assert(!toType.isAggregateType() || isNonPtrVector(toType)); + assert(!toType.isAggregateType() || + isNonPtrVector(toType, m.getState().getVscale())); auto bitsize = toType.bits(); assert(divide_up(bitsize, bits_byte) == bytes.size()); @@ -699,7 +701,8 @@ static StateValue bytesToValue(const Memory &m, const vector &bytes, val = first ? std::move(v) : v.concat(val); first = false; } - return toType.fromInt(val.trunc(bitsize, toType.np_bits(true))); + return toType.fromInt(val.trunc(bitsize, toType.np_bits(true)), + m.getState().getVscale()); } } @@ -2230,20 +2233,20 @@ void Memory::free(const expr &ptr, bool unconstrained) { } } -unsigned Memory::getStoreByteSize(const Type &ty) { +unsigned Memory::getStoreByteSize(const Type &ty, expr vscaleRange) { assert(bits_program_pointer != 0); if (ty.isPtrType()) return divide_up(bits_program_pointer, 8); auto aty = ty.getAsAggregateType(); - if (aty && !isNonPtrVector(ty)) { + if (aty && !isNonPtrVector(ty, vscaleRange)) { unsigned sz = 0; - for (unsigned i = 0, e = aty->numElementsConst(); i < e; ++i) - sz += getStoreByteSize(aty->getChild(i)); + for (unsigned i = 0, e = aty->numElementsConst(vscaleRange); i < e; ++i) + sz += getStoreByteSize(aty->getChild(i), vscaleRange); return sz; } - return divide_up(ty.bits(), 8); + return divide_up(ty.bits(vscaleRange), 8); } void Memory::store(const StateValue &v, const Type &type, unsigned offset0, @@ -2251,20 +2254,23 @@ void Memory::store(const StateValue &v, const Type &type, unsigned offset0, unsigned bytesz = bits_byte / 8; auto aty = type.getAsAggregateType(); - if (aty && !isNonPtrVector(type)) { + if (aty && !isNonPtrVector(type, state->getVscale())) { unsigned byteofs = 0; - for (unsigned i = 0, e = aty->numElementsConst(); i < e; ++i) { + for (unsigned i = 0, e = aty->numElementsConst(state->getVscale()); i < e; + ++i) { auto &child = aty->getChild(i); - if (child.bits() == 0) + if (child.bits(state->getVscale()) == 0) continue; - store(aty->extract(v, i), child, offset0 + byteofs, data); - byteofs += getStoreByteSize(child); + store(aty->extract(v, i, state->getVscale()), child, offset0 + byteofs, + data); + byteofs += getStoreByteSize(child, state->getVscale()); } - assert(byteofs == getStoreByteSize(type)); + assert(byteofs == getStoreByteSize(type, state->getVscale())); } else { vector bytes = valueToBytes(v, type, *this, *state); - assert(!v.isValid() || bytes.size() * bytesz == getStoreByteSize(type)); + assert(!v.isValid() || + bytes.size() * bytesz == getStoreByteSize(type, state->getVscale())); for (unsigned i = 0, e = bytes.size(); i < e; ++i) { unsigned offset = little_endian ? i * bytesz : (e - i - 1) * bytesz; @@ -2280,7 +2286,8 @@ void Memory::store(const expr &p, const StateValue &v, const Type &type, // initializer stores are ok by construction if (!state->isInitializationPhase()) - state->addUB(ptr.isDereferenceable(getStoreByteSize(type), align, true)); + state->addUB(ptr.isDereferenceable( + getStoreByteSize(type, state->getVscale()), align, true)); vector> to_store; store(v, type, 0, to_store); @@ -2289,26 +2296,27 @@ void Memory::store(const expr &p, const StateValue &v, const Type &type, StateValue Memory::load(const Pointer &ptr, const Type &type, set &undef, uint64_t align) { - unsigned bytecount = getStoreByteSize(type); + unsigned bytecount = getStoreByteSize(type, state->getVscale()); auto aty = type.getAsAggregateType(); - if (aty && !isNonPtrVector(type)) { + if (aty && !isNonPtrVector(type, state->getVscale())) { vector member_vals; unsigned byteofs = 0; - for (unsigned i = 0, e = aty->numElementsConst(); i < e; ++i) { + for (unsigned i = 0, e = aty->numElementsConst(state->getVscale()); i < e; + ++i) { // Padding is filled with poison. if (aty->isPadding(i)) { - byteofs += getStoreByteSize(aty->getChild(i)); + byteofs += getStoreByteSize(aty->getChild(i), state->getVscale()); continue; } auto ptr_i = ptr + byteofs; auto align_i = gcd(align, byteofs % align); member_vals.emplace_back(load(ptr_i, aty->getChild(i), undef, align_i)); - byteofs += getStoreByteSize(aty->getChild(i)); + byteofs += getStoreByteSize(aty->getChild(i), state->getVscale()); } assert(byteofs == bytecount); - return aty->aggregateVals(member_vals); + return aty->aggregateVals(member_vals, state->getVscale()); } bool is_ptr = type.isPtrType(); @@ -2350,7 +2358,8 @@ Memory::load(const expr &p, const Type &type, uint64_t align) { assert(!memory_unused()); Pointer ptr(*this, p); - auto ubs = ptr.isDereferenceable(getStoreByteSize(type), align, false); + auto ubs = ptr.isDereferenceable(getStoreByteSize(type, state->getVscale()), + align, false); set undef_vars; auto ret = load(ptr, type, undef_vars, align); return { state->rewriteUndef(std::move(ret), undef_vars), std::move(ubs) }; @@ -2514,7 +2523,8 @@ void Memory::copy(const Pointer &src, const Pointer &dst) { void Memory::fillPoison(const expr &bid) { Pointer p(*this, bid, expr::mkUInt(0, bits_for_offset)); expr blksz = p.blockSizeAligned(); - memset(std::move(p).release(), IntType("i8", 8).getDummyValue(false), + memset(std::move(p).release(), + IntType("i8", 8).getDummyValue(false, state->getVscale()), std::move(blksz), bits_byte / 8, {}, false); } diff --git a/ir/memory.h b/ir/memory.h index f41857627..79171f7fc 100644 --- a/ir/memory.h +++ b/ir/memory.h @@ -7,11 +7,11 @@ #include "ir/functions.h" #include "ir/pointer.h" #include "ir/state_value.h" -#include "ir/type.h" #include "smt/expr.h" #include "smt/exprs.h" #include "util/spaceship.h" #include +#include #include #include #include @@ -321,7 +321,7 @@ class Memory { // are not checked. void free(const smt::expr &ptr, bool unconstrained); - static unsigned getStoreByteSize(const Type &ty); + static unsigned getStoreByteSize(const Type &ty, smt::expr vscaleRange); void store(const smt::expr &ptr, const StateValue &val, const Type &type, uint64_t align, const std::set &undef_vars); std::pair> diff --git a/ir/precondition.cpp b/ir/precondition.cpp index 7b66b2a47..0a9959791 100644 --- a/ir/precondition.cpp +++ b/ir/precondition.cpp @@ -14,7 +14,7 @@ using namespace util; namespace IR { -expr Predicate::getTypeConstraints() const { +expr Predicate::getTypeConstraints(const Function &f) const { return true; } @@ -145,10 +145,10 @@ expr FnPred::toSMT(State &s) const { return r; } -expr FnPred::getTypeConstraints() const { +expr FnPred::getTypeConstraints(const Function &f) const { expr r(true); for (auto a : args) { - r &= a->getTypeConstraints(); + r &= a->getTypeConstraints(f); } switch (fn) { case AddNSW: @@ -210,8 +210,8 @@ expr CmpPred::toSMT(State &s) const { return { ap && bp && std::move(r) }; } -expr CmpPred::getTypeConstraints() const { - return lhs.getTypeConstraints() && +expr CmpPred::getTypeConstraints(const Function &f) const { + return lhs.getTypeConstraints(f) && lhs.getType().enforceIntType() && lhs.getType() == rhs.getType(); } diff --git a/ir/precondition.h b/ir/precondition.h index 808d726bd..7e3877a86 100644 --- a/ir/precondition.h +++ b/ir/precondition.h @@ -17,7 +17,7 @@ class Predicate { public: virtual void print(std::ostream &os) const = 0; virtual smt::expr toSMT(State &s) const = 0; - virtual smt::expr getTypeConstraints() const; + virtual smt::expr getTypeConstraints(const Function &f) const; virtual void fixupTypes(const smt::Model &m); virtual ~Predicate() {} }; @@ -49,7 +49,7 @@ class FnPred final : public Predicate { FnPred(std::string_view name, std::vector &&args); void print(std::ostream &os) const override; smt::expr toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; void fixupTypes(const smt::Model &m) override; }; @@ -73,7 +73,7 @@ class CmpPred final : public Predicate { void print(std::ostream &os) const override; smt::expr toSMT(State &s) const override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; void fixupTypes(const smt::Model &m) override; }; diff --git a/ir/state.cpp b/ir/state.cpp index 3e940eb65..181ba3707 100644 --- a/ir/state.cpp +++ b/ir/state.cpp @@ -255,7 +255,20 @@ State::State(const Function &f, bool source) : f(f), source(source), memory(*this), fp_rounding_mode(expr::mkVar("fp_rounding_mode", 3)), fp_denormal_mode(expr::mkVar("fp_denormal_mode", 2)), - return_val(DisjointExpr(f.getType().getDummyValue(false))) {} + vscale_data(vscaleFromAttr(f.getFnAttrs().vscaleRange)), + return_val(DisjointExpr(f.getType().getDummyValue(false, vscale_data))) {} + +expr State::vscaleFromAttr( + std::optional> vscaleAttr) { + if (vscaleAttr) { + auto [low, high] = *vscaleAttr; + unsigned r = 0; + for (unsigned i = ilog2(low); i <= ilog2(high); ++i) + r |= 1 << i; + return expr::mkUInt(r, var_vector_elements); + } + return expr::mkVscaleMin(); +} void State::resetGlobals() { Memory::resetGlobals(); @@ -296,13 +309,15 @@ static expr eq_except_padding(const Memory &m, const Type &ty, const expr &e1, StateValue sv1{expr(e1), expr()}; StateValue sv2{expr(e2), expr()}; expr result = true; + auto vscaleRange = m.getState().getVscale(); - for (unsigned i = 0; i < aty->numElementsConst(); ++i) { + for (unsigned i = 0; i < aty->numElementsConst(vscaleRange); ++i) { if (aty->isPadding(i)) continue; - result &= eq_except_padding(m, aty->getChild(i), aty->extract(sv1, i).value, - aty->extract(sv2, i).value, ptr_compare); + result &= eq_except_padding( + m, aty->getChild(i), aty->extract(sv1, i, vscaleRange).value, + aty->extract(sv2, i, vscaleRange).value, ptr_compare); } return result; } @@ -596,7 +611,8 @@ const StateValue& State::getAndAddUndefs(const Value &val) { return v; } -static expr not_poison_except_padding(const Type &ty, const expr &np) { +static expr not_poison_except_padding(const State &s, const Type &ty, + const expr &np) { const auto *aty = ty.getAsAggregateType(); if (!aty) { assert(!np.isValid() || np.isBool()); @@ -606,12 +622,12 @@ static expr not_poison_except_padding(const Type &ty, const expr &np) { StateValue sv{expr(), expr(np)}; expr result = true; - for (unsigned i = 0; i < aty->numElementsConst(); ++i) { + for (unsigned i = 0; i < aty->numElementsConst(s.getVscale()); ++i) { if (aty->isPadding(i)) continue; - result &= not_poison_except_padding(aty->getChild(i), - aty->extract(sv, i).non_poison); + result &= not_poison_except_padding( + s, aty->getChild(i), aty->extract(sv, i, s.getVscale()).non_poison); } return result; } @@ -658,7 +674,7 @@ State::getAndAddPoisonUB(const Value &val, bool undef_ub_too, } // If val is an aggregate, all elements should be non-poison - addUB(not_poison_except_padding(val.getType(), sv.non_poison)); + addUB(not_poison_except_padding(*this, val.getType(), sv.non_poison)); } check_enough_tmp_slots(); @@ -697,7 +713,7 @@ bool State::isAsmMode() const { expr State::getPath(BasicBlock &bb) const { if (&f.getFirstBB() == &bb) return true; - + auto I = predecessor_data.find(&bb); if (I == predecessor_data.end()) return false; // Block is unreachable @@ -1110,7 +1126,8 @@ State::addFnCall(const string &name, vector &&inputs, : analysis.ranges_fn_calls; if (ret_arg_ty && (*ret_arg_ty == out_type).isFalse()) { - ret_arg = out_type.fromInt(ret_arg_ty->toInt(*this, std::move(ret_arg))); + ret_arg = out_type.fromInt(ret_arg_ty->toInt(*this, std::move(ret_arg)), + getVscale()); } // source may create new fn symbols, target just references src symbols @@ -1143,8 +1160,8 @@ State::addFnCall(const string &name, vector &&inputs, return { std::move(val), mk_np(true) }; } - if (!hasPtr(ty)) { - auto dummy = ty.getDummyValue(true); + if (!hasPtr(ty, getVscale())) { + auto dummy = ty.getDummyValue(true, getVscale()); return { expr::mkFreshVar(name.c_str(), dummy.value), mk_np(std::move(dummy.non_poison)) }; } @@ -1152,11 +1169,12 @@ State::addFnCall(const string &name, vector &&inputs, assert(ty.isAggregateType()); auto agg = ty.getAsAggregateType(); vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(getVscale()); i != e; + ++i) { if (!agg->isPadding(i)) vals.emplace_back(mk_output(agg->getChild(i))); } - return agg->aggregateVals(vals); + return agg->aggregateVals(vals, getVscale()); }; output = ret_arg_ty ? std::move(ret_arg) : mk_output(out_type); @@ -1186,14 +1204,14 @@ State::addFnCall(const string &name, vector &&inputs, } } - I->second - = { std::move(output), expr::mkFreshVar((name + "#ub").c_str(), false), - (noret || willret) - ? expr(noret) - : expr::mkFreshVar((name + "#noreturn").c_str(), false), - memory.mkCallState(name, attrs.has(FnAttrs::NoFree), - I->first.args_ptr.size(), memaccess), - std::move(ret_data) }; + I->second = {std::move(output), + expr::mkFreshVar((name + "#ub").c_str(), false), + (noret || willret) + ? expr(noret) + : expr::mkFreshVar((name + "#noreturn").c_str(), false), + memory.mkCallState(name, attrs.has(FnAttrs::NoFree), + I->first.args_ptr.size(), memaccess), + std::move(ret_data)}; // add equality constraints between source's function calls for (auto II = calls_fn.begin(), E = calls_fn.end(); II != E; ++II) { @@ -1245,17 +1263,18 @@ State::addFnCall(const string &name, vector &&inputs, std::move(val.non_poison) }; } - if (!hasPtr(ty)) + if (!hasPtr(ty, getVscale())) return std::move(val); assert(ty.isAggregateType()); auto agg = ty.getAsAggregateType(); vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(getVscale()); i != e; + ++i) { vals.emplace_back( - mk_output(agg->getChild(i), agg->extract(val, i))); + mk_output(agg->getChild(i), agg->extract(val, i, getVscale()))); } - return agg->aggregateVals(vals); + return agg->aggregateVals(vals, getVscale()); }; retval = mk_output(out_type, std::move(d.retval)); } else @@ -1268,7 +1287,7 @@ State::addFnCall(const string &name, vector &&inputs, fn_call_qvars.emplace(std::move(qvar)); } else { addUB(expr(false)); - retval = out_type.getDummyValue(false); + retval = out_type.getDummyValue(false, getVscale()); } } diff --git a/ir/state.h b/ir/state.h index 37ce7ff0b..fb59ef012 100644 --- a/ir/state.h +++ b/ir/state.h @@ -182,6 +182,9 @@ class State { std::array tmp_values; unsigned i_tmp_values = 0; // next available position in tmp_values + // for scalable vectors + smt::expr vscale_data; + void check_enough_tmp_slots(); // return_domain: a boolean expression describing return condition @@ -303,6 +306,9 @@ class State { unsigned indirect_call_hash); auto& getVarArgsData() { return var_args_data.data; } + const smt::expr &getVscale() const { return vscale_data; } + static smt::expr + vscaleFromAttr(std::optional> vscaleAttr); void doesApproximation(std::string &&name, std::optional e = {}); auto& getApproximations() const { return used_approximations; } diff --git a/ir/type.cpp b/ir/type.cpp index 404fae097..d14421cd9 100644 --- a/ir/type.cpp +++ b/ir/type.cpp @@ -2,6 +2,8 @@ // Distributed under the MIT license that can be found in the LICENSE file. #include "ir/type.h" +#include "ir/attrs.h" +#include "ir/function.h" #include "ir/globals.h" #include "ir/state.h" #include "smt/solver.h" @@ -17,14 +19,12 @@ using namespace std; static constexpr unsigned var_type_bits = 3; static constexpr unsigned var_bw_bits = 11; -static constexpr unsigned var_vector_elements = 16; - namespace IR { VoidType Type::voidTy; -unsigned Type::np_bits(bool fromInt) const { +unsigned Type::np_bits(bool fromInt, expr) const { if (!fromInt) return 1; auto bw = bits(); @@ -242,7 +242,7 @@ expr Type::toInt(State &s, expr v) const { } StateValue Type::toInt(State &s, StateValue v) const { - auto bw = np_bits(true); + auto bw = np_bits(true, s.getVscale()); return { toInt(s, std::move(v.value)), expr::mkIf(v.non_poison, expr::mkInt(-1, bw), expr::mkUInt(0, bw)) }; } @@ -251,7 +251,7 @@ expr Type::fromInt(expr e) const { return fromBV(std::move(e)); } -StateValue Type::fromInt(StateValue v) const { +StateValue Type::fromInt(StateValue v, expr) const { return { fromInt(std::move(v.value)), v.non_poison.isBool() ? expr(v.non_poison) @@ -264,7 +264,7 @@ expr Type::combine_poison(const expr &boolean, const expr &orig) const { } StateValue Type::mkUndef(State &s) const { - auto val = getDummyValue(true); + auto val = getDummyValue(true, s.getVscale()); expr var = expr::mkFreshVar("undef", val.value); s.addUndefVar(expr(var)); return { std::move(var), std::move(val.non_poison) }; @@ -289,15 +289,15 @@ string Type::toString() const { Type::~Type() {} -unsigned VoidType::bits() const { +unsigned VoidType::bits(expr) const { UNREACHABLE(); } -StateValue VoidType::getDummyValue(bool non_poison) const { +StateValue VoidType::getDummyValue(bool non_poison, expr) const { return { false, non_poison }; } -expr VoidType::getTypeConstraints() const { +expr VoidType::getTypeConstraints(const Function &f) const { return true; } @@ -333,15 +333,15 @@ unsigned IntType::maxSubBitAccess() const { return 0; } -unsigned IntType::bits() const { +unsigned IntType::bits(expr) const { return bitwidth; } -StateValue IntType::getDummyValue(bool non_poison) const { +StateValue IntType::getDummyValue(bool non_poison, expr) const { return { expr::mkUInt(0, bits()), non_poison }; } -expr IntType::getTypeConstraints() const { +expr IntType::getTypeConstraints(const Function &f) const { // since size cannot be unbounded, limit it between 1 and 64 bits if undefined auto bw = sizeVar(); auto r = bw != 0; @@ -525,7 +525,7 @@ expr FloatType::isNaN(const expr &v, bool signalling) const { } } -unsigned FloatType::bits() const { +unsigned FloatType::bits(expr) const { assert(fpType != Unknown); return float_sizes[fpType].first; } @@ -566,11 +566,11 @@ expr FloatType::sizeVar() const { return defined ? expr::mkUInt(bits(), var_bw_bits) : Type::sizeVar(); } -StateValue FloatType::getDummyValue(bool non_poison) const { +StateValue FloatType::getDummyValue(bool non_poison, expr) const { return { expr::mkUInt(0, bits()), non_poison }; } -expr FloatType::getTypeConstraints() const { +expr FloatType::getTypeConstraints(const Function &f) const { if (defined) return true; @@ -663,19 +663,19 @@ expr PtrType::ASVar() const { return defined ? expr::mkUInt(addr_space, 2) : var("as", 2); } -unsigned PtrType::bits() const { +unsigned PtrType::bits(expr) const { return Pointer::totalBits(); } -unsigned PtrType::np_bits(bool fromInt) const { +unsigned PtrType::np_bits(bool fromInt, expr) const { return 1; } -StateValue PtrType::getDummyValue(bool non_poison) const { +StateValue PtrType::getDummyValue(bool non_poison, expr) const { return { expr::mkUInt(0, bits()), non_poison }; } -expr PtrType::getTypeConstraints() const { +expr PtrType::getTypeConstraints(const Function &f) const { return sizeVar() == bits(); } @@ -713,7 +713,7 @@ expr PtrType::fromInt(expr v) const { return v; } -StateValue PtrType::fromInt(StateValue v) const { +StateValue PtrType::fromInt(StateValue v, expr) const { return Type::fromInt(std::move(v)); } @@ -781,13 +781,15 @@ AggregateType::AggregateType(string &&name, vector &&vchildren, elements = children.size(); } -expr AggregateType::numElements() const { - return defined ? expr::mkUInt(elements, var_vector_elements) : - var("elements", var_vector_elements); +expr AggregateType::numElements(expr vscaleRange) const { + return defined + ? expr::mkUInt(numElementsConst(vscaleRange), var_vector_elements) + : var("elements", var_vector_elements); } -unsigned AggregateType::numPaddingsConst() const { - return is_padding.empty() ? 0 : countPaddings(is_padding.size() - 1); +unsigned AggregateType::numPaddingsConst(expr vscaleRange) const { + unsigned elems = numElementsConst(vscaleRange); + return elems ? countPaddings(elems - 1) : 0; } unsigned AggregateType::countPaddings(unsigned to_idx) const { @@ -797,22 +799,24 @@ unsigned AggregateType::countPaddings(unsigned to_idx) const { return count; } -expr AggregateType::numElementsExcludingPadding() const { - auto elems = numElements(); - return numElements() - expr::mkInt(numPaddingsConst(), elems); +expr AggregateType::numElementsExcludingPadding(expr vscaleRange) const { + auto elems = numElements(vscaleRange); + return elems - expr::mkInt(numPaddingsConst(vscaleRange), elems); } -StateValue AggregateType::aggregateVals(const vector &vals) const { - assert(vals.size() + numPaddingsConst() == elements); +StateValue AggregateType::aggregateVals(const vector &vals, + expr vscaleRange) const { + unsigned elems = numElementsConst(vscaleRange); + assert(vals.size() + numPaddingsConst(vscaleRange) == elems); // structs can be empty - if (elements == 0) + if (elems == 0) return { expr::mkUInt(0, 1), expr::mkUInt(0, 1) }; StateValue v; bool first = true; unsigned val_idx = 0; - for (unsigned idx = 0; idx < elements; ++idx) { - if (children[idx]->bits() == 0) { + for (unsigned idx = 0; idx < elems; ++idx) { + if (children[idx]->bits(vscaleRange) == 0) { assert(!isPadding(idx)); val_idx++; continue; @@ -820,7 +824,7 @@ StateValue AggregateType::aggregateVals(const vector &vals) const { StateValue vv; if (isPadding(idx)) - vv = children[idx]->getDummyValue(false); + vv = children[idx]->getDummyValue(false, vscaleRange); else vv = vals[val_idx++]; vv = children[idx]->toBV(std::move(vv)); @@ -831,26 +835,26 @@ StateValue AggregateType::aggregateVals(const vector &vals) const { } StateValue AggregateType::extract(const StateValue &val, unsigned index, - bool fromInt) const { + expr vscaleRange, bool fromInt) const { unsigned total_value = 0, total_np = 0; for (unsigned i = 0; i < index; ++i) { - total_value += children[i]->bits(); + total_value += children[i]->bits(vscaleRange); total_np += children[i]->np_bits(fromInt); } unsigned h_val, l_val, h_np, l_np; if (fromInt && little_endian) { - h_val = total_value + children[index]->bits() - 1; + h_val = total_value + children[index]->bits(vscaleRange) - 1; l_val = total_value; h_np = total_np + children[index]->np_bits(fromInt) - 1; l_np = total_np; } else { - unsigned high_val = bits() - total_value; + unsigned high_val = bits(vscaleRange) - total_value; h_val = high_val - 1; - l_val = high_val - children[index]->bits(); + l_val = high_val - children[index]->bits(vscaleRange); - unsigned high_np = np_bits(fromInt) - total_np; + unsigned high_np = np_bits(fromInt, vscaleRange) - total_np; h_np = high_np - 1; l_np = high_np - children[index]->np_bits(fromInt); } @@ -861,43 +865,47 @@ StateValue AggregateType::extract(const StateValue &val, unsigned index, children[index]->fromBV(std::move(sv)); } -unsigned AggregateType::bits() const { - if (elements == 0) +unsigned AggregateType::bits(expr vscaleRange) const { + unsigned elems = numElementsConst(vscaleRange); + if (elems == 0) // It is set as 1 because zero-width bitvector is invalid. return 1; unsigned bw = 0; - for (unsigned i = 0; i < elements; ++i) { - bw += children[i]->bits(); + for (unsigned i = 0; i < elems; ++i) { + bw += children[i]->bits(vscaleRange); } return bw; } -unsigned AggregateType::np_bits(bool fromInt) const { - if (elements == 0) +unsigned AggregateType::np_bits(bool fromInt, expr vscaleRange) const { + unsigned elems = numElementsConst(vscaleRange); + if (elems == 0) // It is set as 1 because zero-width bitvector is invalid. return 1; unsigned bw = 0; - for (unsigned i = 0; i < elements; ++i) { + for (unsigned i = 0; i < elems; ++i) { bw += children[i]->np_bits(fromInt); } return bw; } -StateValue AggregateType::getDummyValue(bool non_poison) const { +StateValue AggregateType::getDummyValue(bool non_poison, + expr vscaleRange) const { + unsigned elems = numElementsConst(vscaleRange); vector vals; - for (unsigned i = 0; i < elements; ++i) { + for (unsigned i = 0; i < elems; ++i) { if (!isPadding(i)) - vals.emplace_back(children[i]->getDummyValue(non_poison)); + vals.emplace_back(children[i]->getDummyValue(non_poison, vscaleRange)); } - return aggregateVals(vals); + return aggregateVals(vals, vscaleRange); } -expr AggregateType::getTypeConstraints() const { +expr AggregateType::getTypeConstraints(const Function &f) const { expr r(true), elems = numElements(); for (unsigned i = 0, e = children.size(); i != e; ++i) { - r &= elems.ugt(i).implies(children[i]->getTypeConstraints()); + r &= elems.ugt(i).implies(children[i]->getTypeConstraints(f)); } if (!defined) r &= elems.ule(4); @@ -970,13 +978,15 @@ expr AggregateType::toInt(State &s, expr v) const { } StateValue AggregateType::toInt(State &s, StateValue v) const { + unsigned elems = numElementsConst(s.getVscale()); + // structs can be empty - if (elements == 0) + if (elems == 0) return { expr::mkUInt(0, 1), expr::mkUInt(1, 1) }; StateValue ret; - for (unsigned i = 0; i < elements; ++i) { - auto vv = children[i]->toInt(s, extract(v, i)); + for (unsigned i = 0; i < elems; ++i) { + auto vv = children[i]->toInt(s, extract(v, i, s.getVscale())); ret = i == 0 ? std::move(vv) : (little_endian ? vv.concat(ret) : ret.concat(vv)); } return ret; @@ -986,10 +996,11 @@ expr AggregateType::fromInt(expr v) const { UNREACHABLE(); } -StateValue AggregateType::fromInt(StateValue v) const { +StateValue AggregateType::fromInt(StateValue v, expr vscaleRange) const { + unsigned elems = numElementsConst(vscaleRange); vector child_vals; - for (unsigned i = 0; i < elements; ++i) - child_vals.emplace_back(extract(v, i, true)); + for (unsigned i = 0; i < elems; ++i) + child_vals.emplace_back(extract(v, i, vscaleRange, true)); return this->aggregateVals(child_vals); } @@ -997,9 +1008,11 @@ pair AggregateType::refines(State &src_s, State &tgt_s, const StateValue &src, const StateValue &tgt) const { set poison, value; - for (unsigned i = 0; i < elements; ++i) { - auto [p, v] = children[i]->refines(src_s, tgt_s, extract(src, i), - extract(tgt, i)); + expr vscaleRange = src_s.getVscale(); + for (unsigned i = 0; i < numElementsConst(vscaleRange); ++i) { + auto [p, v] = + children[i]->refines(src_s, tgt_s, extract(src, i, vscaleRange), + extract(tgt, i, vscaleRange)); poison.insert(std::move(p)); value.insert(std::move(v)); } @@ -1007,12 +1020,13 @@ AggregateType::refines(State &src_s, State &tgt_s, const StateValue &src, } StateValue AggregateType::mkUndef(State &s) const { + unsigned elems = numElementsConst(s.getVscale()); vector vals; - for (unsigned i = 0; i < elements; ++i) { + for (unsigned i = 0; i < elems; ++i) { if (!isPadding(i)) vals.emplace_back(children[i]->mkUndef(s)); } - return aggregateVals(vals); + return aggregateVals(vals, s.getVscale()); } expr AggregateType::mkInput(State &s, const char *name, @@ -1077,30 +1091,39 @@ void ArrayType::print(ostream &os) const { } } - -VectorType::VectorType(string &&name, unsigned elements, Type &elementTy) - : AggregateType(std::move(name), false) { - assert(elements != 0); - this->elements = elements; +VectorType::VectorType(string &&name, unsigned minElems, Type &elementTy, + bool isScalableTy) + : AggregateType(std::move(name), false) { + assert(minElems != 0); + this->isScalableTy = isScalableTy; + this->elements = minElems; defined = true; - children.resize(elements, &elementTy); - is_padding.resize(elements, false); + unsigned scaleFactor = isScalableTy ? var_vector_max_vscale : 1; + children.resize(elements * scaleFactor, &elementTy); + is_padding.resize(elements * scaleFactor, false); +} + +unsigned VectorType::numElementsConst(expr vscaleRange) const { + unsigned scaleFactor = + isScalable() ? 1 << (vscaleRange.active_bits() - 1) : 1; + return elements * scaleFactor; } StateValue VectorType::extract(const StateValue &vector, - const expr &index) const { + const expr &index, expr vscaleRange) const { + unsigned elems = numElementsConst(vscaleRange); auto &elementTy = *children[0]; - unsigned bw_elem = elementTy.bits(); - unsigned bw_val = bw_elem * elements; + unsigned bw_elem = elementTy.bits(vscaleRange); + unsigned bw_val = bw_elem * elems; expr idx_v = index.zextOrTrunc(bw_val) * expr::mkUInt(bw_elem, bw_val); - unsigned h_val = elements * bw_elem - 1; - unsigned l_val = (elements - 1) * bw_elem; + unsigned h_val = elems * bw_elem - 1; + unsigned l_val = (elems - 1) * bw_elem; - unsigned bw_np_elem = elementTy.np_bits(false); - unsigned bw_np = bw_np_elem * elements; + unsigned bw_np_elem = elementTy.np_bits(false, vscaleRange); + unsigned bw_np = bw_np_elem * elems; expr idx_np = index.zextOrTrunc(bw_np) * expr::mkUInt(bw_np_elem, bw_np); - unsigned h_np = elements * bw_np_elem - 1; - unsigned l_np = (elements - 1) * bw_np_elem; + unsigned h_np = elems * bw_np_elem - 1; + unsigned l_np = (elems - 1) * bw_np_elem; return elementTy.fromBV({(vector.value << idx_v).extract(h_val, l_val), (vector.non_poison << idx_np).extract(h_np, l_np)}); @@ -1108,22 +1131,24 @@ StateValue VectorType::extract(const StateValue &vector, StateValue VectorType::update(const StateValue &vector, const StateValue &val, - const expr &index) const { + const expr &index, + expr vscaleRange) const { auto &elementTy = *children[0]; StateValue val_bv = elementTy.toBV(val); + unsigned elems = numElementsConst(vscaleRange); - if (elements == 1) + if (elems == 1) return val_bv; - unsigned bw_elem = elementTy.bits(); - unsigned bw_val = bw_elem * elements; + unsigned bw_elem = elementTy.bits(vscaleRange); + unsigned bw_val = bw_elem * elems; expr idx_v = index.zextOrTrunc(bw_val) * expr::mkUInt(bw_elem, bw_val); expr fill_v = expr::mkUInt(0, bw_val - bw_elem); expr mask_v = ~expr::mkInt(-1, bw_elem).concat(fill_v).lshr(idx_v); expr nv_shifted = val_bv.value.concat(fill_v).lshr(idx_v); - unsigned bw_np_elem = elementTy.np_bits(false); - unsigned bw_np = bw_np_elem * elements; + unsigned bw_np_elem = elementTy.np_bits(false, vscaleRange); + unsigned bw_np = bw_np_elem * elems; expr idx_np = index.zextOrTrunc(bw_np) * expr::mkUInt(bw_np_elem, bw_np); expr fill_np = expr::mkUInt(0, bw_np - bw_np_elem); expr mask_np = ~expr::mkInt(-1, bw_np_elem).concat(fill_np).lshr(idx_np); @@ -1133,19 +1158,34 @@ StateValue VectorType::update(const StateValue &vector, (vector.non_poison & mask_np) | np_shifted}); } -expr VectorType::getTypeConstraints() const { +expr VectorType::getTypeConstraints(const Function &f) const { + auto vscaleAttr = f.getFnAttrs().vscaleRange; + if (isScalable()) { + // TODO: if we don't have a vscale_range on the function, fail the type + // check for now. + // If we don't havethe underlying storage for the high range of the vscale, + // fail the type check. + if (!vscaleAttr || vscaleAttr->second > var_vector_max_vscale) + return false; + } + auto &elementTy = *children[0]; - expr r = AggregateType::getTypeConstraints() && + expr vscaleRange = State::vscaleFromAttr(vscaleAttr); + expr elems = numElements(vscaleRange); + expr r = AggregateType::getTypeConstraints(f) && (elementTy.enforceIntType() || elementTy.enforceFloatType() || elementTy.enforcePtrType()) && - numElements() != 0; + elems != 0; // all elements have the same type for (unsigned i = 1, e = children.size(); i != e; ++i) { - r &= numElements().ugt(i).implies(elementTy == *children[i]); + r &= elems.ugt(i).implies(elementTy == *children[i]); } + // TODO: remove once scalable vectors are fully supported. + r &= vscaleRange.isPowerOf2(); + return r; } @@ -1157,6 +1197,16 @@ bool VectorType::isVectorType() const { return true; } +expr VectorType::operator==(const VectorType &rhs) const { + expr res = this->AggregateType::operator==(rhs); + res &= isScalable() == rhs.isScalable(); + return res; +} + +bool VectorType::isScalable() const { + return isScalableTy; +} + expr VectorType::enforceVectorType( const function &enforceElem) const { return enforceElem(*children[0]); @@ -1164,7 +1214,8 @@ expr VectorType::enforceVectorType( void VectorType::print(ostream &os) const { if (elements) - os << '<' << elements << " x " << *children[0] << '>'; + os << '<' << (isScalable() ? "vscale x " : "") << elements << " x " + << *children[0] << '>'; } @@ -1246,26 +1297,26 @@ SymbolicType::SymbolicType(string &&name, unsigned type_mask) ret = ret.isValid() ? expr::mkIf(isStruct(), s->call, ret) : s->call; \ return ret -unsigned SymbolicType::bits() const { - DISPATCH(bits(), UNREACHABLE()); +unsigned SymbolicType::bits(expr dummy) const { + DISPATCH(bits(dummy), UNREACHABLE()); } -unsigned SymbolicType::np_bits(bool fromInt) const { - DISPATCH(np_bits(fromInt), UNREACHABLE()); +unsigned SymbolicType::np_bits(bool fromInt, expr dummy) const { + DISPATCH(np_bits(fromInt, dummy), UNREACHABLE()); } -StateValue SymbolicType::getDummyValue(bool non_poison) const { - DISPATCH(getDummyValue(non_poison), UNREACHABLE()); +StateValue SymbolicType::getDummyValue(bool non_poison, expr dummy) const { + DISPATCH(getDummyValue(non_poison, dummy), UNREACHABLE()); } -expr SymbolicType::getTypeConstraints() const { +expr SymbolicType::getTypeConstraints(const Function &fn) const { expr c(false); - if (i) c |= isInt() && i->getTypeConstraints(); - if (f) c |= isFloat() && f->getTypeConstraints(); - if (p) c |= isPtr() && p->getTypeConstraints(); - if (a) c |= isArray() && a->getTypeConstraints(); - if (v) c |= isVector() && v->getTypeConstraints(); - if (s) c |= isStruct() && s->getTypeConstraints(); + if (i) c |= isInt() && i->getTypeConstraints(fn); + if (f) c |= isFloat() && f->getTypeConstraints(fn); + if (p) c |= isPtr() && p->getTypeConstraints(fn); + if (a) c |= isArray() && a->getTypeConstraints(fn); + if (v) c |= isVector() && v->getTypeConstraints(fn); + if (s) c |= isStruct() && s->getTypeConstraints(fn); return c; } @@ -1430,8 +1481,8 @@ expr SymbolicType::fromInt(expr e) const { DISPATCH(fromInt(std::move(e)), UNREACHABLE()); } -StateValue SymbolicType::fromInt(StateValue val) const { - DISPATCH(fromInt(std::move(val)), UNREACHABLE()); +StateValue SymbolicType::fromInt(StateValue val, expr dummy) const { + DISPATCH(fromInt(std::move(val), dummy), UNREACHABLE()); } pair @@ -1463,34 +1514,35 @@ void SymbolicType::print(ostream &os) const { } -bool hasPtr(const Type &t) { +bool hasPtr(const Type &t, expr vscaleRange) { if (t.isPtrType()) return true; if (auto agg = t.getAsAggregateType()) { - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { - if (hasPtr(agg->getChild(i))) + for (unsigned i = 0, e = agg->numElementsConst(vscaleRange); i != e; ++i) { + if (hasPtr(agg->getChild(i), vscaleRange)) return true; } } return false; } -bool isNonPtrVector(const Type &t) { +bool isNonPtrVector(const Type &t, expr vscaleRange) { auto vty = dynamic_cast(&t); return vty && !vty->getChild(0).isPtrType(); } -unsigned minVectorElemSize(const Type &t) { +unsigned minVectorElemSize(const Type &t, expr vscaleRange) { if (auto agg = t.getAsAggregateType()) { if (t.isVectorType()) { auto &elemTy = agg->getChild(0); - return elemTy.isPtrType() ? IR::bits_program_pointer : elemTy.bits(); + return elemTy.isPtrType() ? IR::bits_program_pointer + : elemTy.bits(vscaleRange); } unsigned val = 0; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { - if (auto ch = minVectorElemSize(agg->getChild(i))) { + for (unsigned i = 0, e = agg->numElementsConst(vscaleRange); i != e; ++i) { + if (auto ch = minVectorElemSize(agg->getChild(i), vscaleRange)) { val = val ? gcd(val, ch) : ch; } } @@ -1499,24 +1551,25 @@ unsigned minVectorElemSize(const Type &t) { return 0; } -uint64_t getCommonAccessSize(const IR::Type &ty) { +uint64_t getCommonAccessSize(const IR::Type &ty, expr vscaleRange) { if (auto agg = ty.getAsAggregateType()) { // non-pointer vectors are stored/loaded all at once if (agg->isVectorType()) { auto &elemTy = agg->getChild(0); if (!elemTy.isPtrType()) - return divide_up(agg->numElementsConst() * elemTy.bits(), 8); + return divide_up( + agg->numElementsConst(vscaleRange) * elemTy.bits(vscaleRange), 8); } uint64_t sz = 1; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { - auto n = getCommonAccessSize(agg->getChild(i)); + for (unsigned i = 0, e = agg->numElementsConst(vscaleRange); i != e; ++i) { + auto n = getCommonAccessSize(agg->getChild(i), vscaleRange); sz = i == 0 ? n : gcd(sz, n); } return sz; } if (ty.isPtrType()) return IR::bits_program_pointer / 8; - return divide_up(ty.bits(), 8); + return divide_up(ty.bits(vscaleRange), 8); } } diff --git a/ir/type.h b/ir/type.h index b2b02c9e7..5cc775714 100644 --- a/ir/type.h +++ b/ir/type.h @@ -26,6 +26,7 @@ class SymbolicType; class VectorType; class VoidType; class State; +class Function; struct StateValue; class Type { @@ -43,13 +44,17 @@ class Type { public: Type(std::string &&name) : name(std::move(name)) {} - virtual unsigned bits() const = 0; - virtual unsigned np_bits(bool fromInt) const; + virtual unsigned + bits(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const = 0; + virtual unsigned + np_bits(bool fromInt, smt::expr vscaleRange = smt::expr::mkVscaleMin()) const; // to use when one needs the corresponding SMT type - virtual IR::StateValue getDummyValue(bool non_poison) const = 0; + virtual IR::StateValue + getDummyValue(bool non_poison, + smt::expr vscaleRange = smt::expr::mkVscaleMin()) const = 0; - virtual smt::expr getTypeConstraints() const = 0; + virtual smt::expr getTypeConstraints(const Function &f) const = 0; virtual smt::expr sizeVar() const; virtual smt::expr scalarSize() const; smt::expr operator==(const Type &rhs) const; @@ -102,7 +107,9 @@ class Type { virtual smt::expr toInt(State &s, smt::expr v) const; virtual IR::StateValue toInt(State &s, IR::StateValue v) const; virtual smt::expr fromInt(smt::expr v) const; - virtual IR::StateValue fromInt(IR::StateValue v) const; + virtual IR::StateValue + fromInt(IR::StateValue v, + smt::expr vscaleRange = smt::expr::mkVscaleMin()) const; // combine existing poison value in BV repr with a new boolean expr smt::expr combine_poison(const smt::expr &boolean, @@ -135,9 +142,10 @@ class Type { class VoidType final : public Type { public: VoidType() : Type("void") {} - unsigned bits() const override; - IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + unsigned + bits(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const override; + IR::StateValue getDummyValue(bool non_poison, smt::expr) const override; + smt::expr getTypeConstraints(const Function &f) const override; void fixup(const smt::Model &m) override; std::pair refines(State &src_s, State &tgt_s, const StateValue &src, @@ -160,9 +168,10 @@ class IntType final : public Type { : Type(std::move(name)), bitwidth(bitwidth), defined(true) {} unsigned maxSubBitAccess() const; - unsigned bits() const override; - IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + unsigned + bits(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const override; + IR::StateValue getDummyValue(bool non_poison, smt::expr) const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const IntType &rhs) const; void fixup(const smt::Model &m) override; @@ -198,7 +207,8 @@ class FloatType final : public Type { FloatType(std::string &&name) : Type(std::move(name)) {} FloatType(std::string &&name, FpType fpType) : Type(std::move(name)), fpType(fpType), defined(true) {} - unsigned bits() const override; + unsigned + bits(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const override; FpType getFpType() const { return fpType; }; smt::expr getDummyFloat() const; @@ -208,8 +218,8 @@ class FloatType final : public Type { const smt::expr &b = {}, const smt::expr &c = {}) const; smt::expr isNaN(const smt::expr &v, bool signalling) const; - IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + IR::StateValue getDummyValue(bool non_poison, smt::expr) const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const FloatType &rhs) const; void fixup(const smt::Model &m) override; @@ -236,10 +246,12 @@ class PtrType final : public Type { PtrType(std::string &&name) : Type(std::move(name)) {} PtrType(unsigned addr_space); - unsigned bits() const override; - unsigned np_bits(bool fromInt) const override; - IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + unsigned + bits(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const override; + unsigned np_bits(bool fromInt, smt::expr vscaleRange = smt::expr::mkUInt( + 1, var_vector_elements)) const override; + IR::StateValue getDummyValue(bool non_poison, smt::expr) const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const PtrType &rhs) const; void fixup(const smt::Model &m) override; @@ -248,7 +260,7 @@ class PtrType final : public Type { smt::expr toInt(State &s, smt::expr v) const override; IR::StateValue toInt(State &s, IR::StateValue v) const override; smt::expr fromInt(smt::expr v) const override; - IR::StateValue fromInt(IR::StateValue v) const override; + IR::StateValue fromInt(IR::StateValue v, smt::expr) const override; std::pair refines(State &src_s, State &tgt_s, const StateValue &src, const StateValue &tgt) const override; @@ -277,23 +289,31 @@ class AggregateType : public Type { std::vector &&is_padding); public: - smt::expr numElements() const; - smt::expr numElementsExcludingPadding() const; - unsigned numElementsConst() const { return elements; } - unsigned numPaddingsConst() const; - - StateValue aggregateVals(const std::vector &vals) const; + smt::expr numElements(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const; + smt::expr numElementsExcludingPadding( + smt::expr vscaleRange = smt::expr::mkVscaleMin()) const; + virtual unsigned + numElementsConst(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const { + return elements; + } + unsigned numPaddingsConst(smt::expr vscaleRange) const; + + StateValue + aggregateVals(const std::vector &vals, + smt::expr vscaleRange = smt::expr::mkVscaleMin()) const; StateValue extract(const StateValue &val, unsigned index, + smt::expr vscaleRange = smt::expr::mkVscaleMin(), bool fromInt = false) const; Type& getChild(unsigned index) const { return *children[index]; } bool isPadding(unsigned i) const { return is_padding[i]; } unsigned countPaddings(unsigned to_idx) const; - unsigned bits() const override; - unsigned np_bits(bool fromInt) const override; + unsigned bits(smt::expr vscaleRange) const override; + unsigned np_bits(bool fromInt, smt::expr vscaleRange) const override; // Padding is filled with poison regardless of non_poison. - IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + IR::StateValue getDummyValue(bool non_poison, + smt::expr vscaleRange) const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr operator==(const AggregateType &rhs) const; void fixup(const smt::Model &m) override; @@ -306,7 +326,8 @@ class AggregateType : public Type { smt::expr toInt(State &s, smt::expr v) const override; IR::StateValue toInt(State &s, IR::StateValue v) const override; smt::expr fromInt(smt::expr v) const override; - IR::StateValue fromInt(IR::StateValue v) const override; + IR::StateValue fromInt(IR::StateValue v, + smt::expr vscaleRange) const override; std::pair refines(State &src_s, State &tgt_s, const StateValue &src, const StateValue &tgt) const override; @@ -332,18 +353,26 @@ class ArrayType final : public AggregateType { class VectorType final : public AggregateType { + bool isScalableTy = false; + public: VectorType(std::string &&name) : AggregateType(std::move(name)) {} - VectorType(std::string &&name, unsigned elements, Type &elementTy); + VectorType(std::string &&name, unsigned minElems, Type &elementTy, + bool isScalableTy = false); + virtual unsigned numElementsConst(smt::expr vscaleRange) const override; IR::StateValue extract(const IR::StateValue &vector, - const smt::expr &index) const; + const smt::expr &index, + smt::expr vscaleRange) const; IR::StateValue update(const IR::StateValue &vector, const IR::StateValue &val, - const smt::expr &idx) const; - smt::expr getTypeConstraints() const override; + const smt::expr &idx, + smt::expr vscaleRange) const; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr scalarSize() const override; bool isVectorType() const override; + smt::expr operator==(const VectorType &rhs) const; + bool isScalable() const; smt::expr enforceVectorType( const std::function &enforceElem) const override; void print(std::ostream &os) const override; @@ -381,10 +410,12 @@ class SymbolicType final : public Type { // use mask of (1 << TypeNum) SymbolicType(std::string &&name, unsigned type_mask); - unsigned bits() const override; - unsigned np_bits(bool fromInt) const override; - IR::StateValue getDummyValue(bool non_poison) const override; - smt::expr getTypeConstraints() const override; + unsigned + bits(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const override; + unsigned np_bits(bool fromInt, smt::expr vscaleRange = smt::expr::mkUInt( + 1, var_vector_elements)) const override; + IR::StateValue getDummyValue(bool non_poison, smt::expr) const override; + smt::expr getTypeConstraints(const Function &f) const override; smt::expr sizeVar() const override; smt::expr scalarSize() const override; smt::expr operator==(const Type &rhs) const; @@ -414,7 +445,7 @@ class SymbolicType final : public Type { smt::expr toInt(State &s, smt::expr v) const override; IR::StateValue toInt(State &s, IR::StateValue v) const override; smt::expr fromInt(smt::expr v) const override; - IR::StateValue fromInt(IR::StateValue v) const override; + IR::StateValue fromInt(IR::StateValue v, smt::expr) const override; std::pair refines(State &src_s, State &tgt_s, const StateValue &src, const StateValue &tgt) const override; @@ -428,9 +459,10 @@ class SymbolicType final : public Type { void print(std::ostream &os) const override; }; - -bool hasPtr(const Type &t); -bool isNonPtrVector(const Type &t); -unsigned minVectorElemSize(const Type &t); -uint64_t getCommonAccessSize(const Type &ty); +bool hasPtr(const Type &t, smt::expr vscaleRange = smt::expr::mkVscaleMin()); +bool isNonPtrVector(const Type &t, smt::expr vscaleRange); +unsigned minVectorElemSize(const Type &t, + smt::expr vscaleRange = smt::expr::mkVscaleMin()); +uint64_t getCommonAccessSize(const Type &ty, + smt::expr vscaleRange = smt::expr::mkVscaleMin()); } diff --git a/ir/value.cpp b/ir/value.cpp index dfb98dab5..731525ffa 100644 --- a/ir/value.cpp +++ b/ir/value.cpp @@ -22,8 +22,8 @@ void Value::rauw(const Value &what, Value &with) { UNREACHABLE(); } -expr Value::getTypeConstraints() const { - return getType().getTypeConstraints(); +expr Value::getTypeConstraints(const Function &f) const { + return getType().getTypeConstraints(f); } void Value::fixupTypes(const Model &m) { @@ -55,7 +55,7 @@ void PoisonValue::print(ostream &os) const { } StateValue PoisonValue::toSMT(State &s) const { - return getType().getDummyValue(false); + return getType().getDummyValue(false, s.getVscale()); } @@ -128,12 +128,12 @@ StateValue GlobalVariable::toSMT(State &s) const { true }; } - -static string agg_str(const Type &ty, vector &vals) { +static string agg_str(const Type &ty, vector &vals, + expr vscaleRange = expr::mkVscaleMin()) { auto agg = ty.getAsAggregateType(); string r = "{ "; unsigned j = 0; - for (unsigned i = 0, e = agg->numElementsConst(); i != e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(vscaleRange); i != e; ++i) { if (i != 0) r += ", "; if (agg->isPadding(i)) @@ -144,15 +144,17 @@ static string agg_str(const Type &ty, vector &vals) { return r + " }"; } -AggregateValue::AggregateValue(Type &type, vector &&vals) - : Value(type, agg_str(type, vals)), vals(std::move(vals)) {} +AggregateValue::AggregateValue(Type &type, vector &&vals, + expr vscaleRange) + : Value(type, agg_str(type, vals, vscaleRange)), vals(std::move(vals)) {} StateValue AggregateValue::toSMT(State &s) const { vector state_vals; for (auto *val : vals) { state_vals.emplace_back(val->toSMT(s)); } - return getType().getAsAggregateType()->aggregateVals(state_vals); + return getType().getAsAggregateType()->aggregateVals(state_vals, + s.getVscale()); } void AggregateValue::rauw(const Value &what, Value &with) { @@ -163,8 +165,8 @@ void AggregateValue::rauw(const Value &what, Value &with) { setName(agg_str(getType(), vals)); } -expr AggregateValue::getTypeConstraints() const { - expr r = Value::getTypeConstraints(); +expr AggregateValue::getTypeConstraints(const Function &f) const { + expr r = Value::getTypeConstraints(f); vector types; for (auto *val : vals) { types.emplace_back(&val->getType()); @@ -172,7 +174,7 @@ expr AggregateValue::getTypeConstraints() const { // Instr's type constraints are already generated by BasicBlock's // getTypeConstraints() continue; - r &= val->getTypeConstraints(); + r &= val->getTypeConstraints(f); } return r && getType().enforceAggregateType(&types); } @@ -214,13 +216,13 @@ string Input::getSMTName(unsigned child) const { StateValue Input::mkInput(State &s, const Type &ty, unsigned child) const { if (auto agg = ty.getAsAggregateType()) { vector vals; - for (unsigned i = 0, e = agg->numElementsConst(); i < e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(s.getVscale()); i < e; ++i) { if (agg->isPadding(i)) continue; auto name = getSMTName(child + i); vals.emplace_back(mkInput(s, agg->getChild(i), child + i)); } - return agg->aggregateVals(vals); + return agg->aggregateVals(vals, s.getVscale()); } expr val; diff --git a/ir/value.h b/ir/value.h index f00afeaee..57538a4d3 100644 --- a/ir/value.h +++ b/ir/value.h @@ -16,6 +16,7 @@ namespace smt { class Model; } namespace IR { class VoidValue; +class Function; class Value { @@ -37,7 +38,7 @@ class Value { virtual void rauw(const Value &what, Value &with); virtual void print(std::ostream &os) const = 0; virtual StateValue toSMT(State &s) const = 0; - virtual smt::expr getTypeConstraints() const; + virtual smt::expr getTypeConstraints(const Function &f) const; void fixupTypes(const smt::Model &m); static VoidValue voidVal; @@ -106,10 +107,11 @@ class GlobalVariable final : public Value { class AggregateValue final : public Value { std::vector vals; public: - AggregateValue(Type &type, std::vector &&vals); + AggregateValue(Type &type, std::vector &&vals, + smt::expr vscaleRange = smt::expr::mkVscaleMin()); auto& getVals() const { return vals; } void rauw(const Value &what, Value &with) override; - smt::expr getTypeConstraints() const override; + smt::expr getTypeConstraints(const Function &f) const override; void print(std::ostream &os) const override; StateValue toSMT(State &s) const override; }; diff --git a/llvm_util/llvm2alive.cpp b/llvm_util/llvm2alive.cpp index a029025c0..061d7d43b 100644 --- a/llvm_util/llvm2alive.cpp +++ b/llvm_util/llvm2alive.cpp @@ -1737,6 +1737,13 @@ class llvm2alive_ : public llvm::InstVisitor> { attrs.set(FnAttrs::NullPointerIsValid); break; + case llvm::Attribute::VScaleRange: { + auto l = llvmattr.getVScaleRangeMin(); + auto r = llvmattr.getVScaleRangeMax().value_or(l); + attrs.vscaleRange = {l, r}; + break; + } + default: break; } diff --git a/llvm_util/utils.cpp b/llvm_util/utils.cpp index a359edea0..928776ed8 100644 --- a/llvm_util/utils.cpp +++ b/llvm_util/utils.cpp @@ -198,7 +198,6 @@ Type* llvm_type2alive(const llvm::Type *ty) { } return cache.get(); } - // TODO: non-fixed sized vectors case llvm::Type::FixedVectorTyID: { auto &cache = type_cache[ty]; if (!cache) { @@ -212,6 +211,19 @@ Type* llvm_type2alive(const llvm::Type *ty) { } return cache.get(); } + case llvm::Type::ScalableVectorTyID: { + auto &cache = type_cache[ty]; + if (!cache) { + auto vty = cast(ty); + auto minelems = vty->getElementCount().getKnownMinValue(); + auto ety = llvm_type2alive(vty->getElementType()); + if (!ety || minelems > 1024) + return nullptr; + cache = make_unique("ty_" + to_string(type_id_counter++), + minelems, *ety, true); + } + return cache.get(); + } case llvm::Type::ArrayTyID: { auto &cache = type_cache[ty]; if (!cache) { @@ -287,8 +299,11 @@ Value* get_operand(llvm::Value *v, if (!ty) return nullptr; + smt::expr vscaleRange = + State::vscaleFromAttr(current_fn->getFnAttrs().vscaleRange); + // automatic splat of constant values - if (auto vty = dyn_cast(v->getType()); + if (auto vty = dyn_cast(v->getType()); vty && isa(v)) { llvm::Value *llvm_splat = nullptr; if (auto cnst = dyn_cast(v)) { @@ -305,8 +320,9 @@ Value* get_operand(llvm::Value *v, if (!splat) return nullptr; - vector vals(vty->getNumElements(), splat); - auto val = make_unique(*ty, std::move(vals)); + unsigned ec = ty->getAsAggregateType()->numElementsConst(vscaleRange); + vector vals(ec, splat); + auto val = make_unique(*ty, std::move(vals), vscaleRange); auto ret = val.get(); current_fn->addConstant(std::move(val)); RETURN_CACHE(ret); @@ -399,7 +415,7 @@ Value* get_operand(llvm::Value *v, { unsigned opi = 0; - for (unsigned i = 0; i < aty->numElementsConst(); ++i) { + for (unsigned i = 0; i < aty->numElementsConst(vscaleRange); ++i) { if (!aty->isPadding(i)) { if (auto op = get_operand(get_elem(opi), constexpr_conv, copy_inserter, register_fn_decl)) diff --git a/smt/expr.cpp b/smt/expr.cpp index 755dc64a9..05a41cfd7 100644 --- a/smt/expr.cpp +++ b/smt/expr.cpp @@ -643,6 +643,10 @@ unsigned expr::min_trailing_ones() const { return 0; } +unsigned expr::active_bits() const { + return bits() - min_leading_zeros(); +} + expr expr::binop_commutative(const expr &rhs, Z3_ast (*op)(Z3_context, Z3_ast, Z3_ast), expr (expr::*expr_op)(const expr &) const, diff --git a/smt/expr.h b/smt/expr.h index daae4add1..93df4a925 100644 --- a/smt/expr.h +++ b/smt/expr.h @@ -12,6 +12,9 @@ #include #include +static constexpr unsigned var_vector_elements = 16; +static constexpr unsigned var_vector_max_vscale = 16; + typedef struct _Z3_context* Z3_context; typedef struct _Z3_func_decl* Z3_decl; typedef struct _Z3_app* Z3_app; @@ -97,6 +100,9 @@ class expr { static expr mkQuadVar(const char *name); static expr mkFreshVar(const char *prefix, const expr &type); + // vscale-specific functions + static expr mkVscaleMin() { return expr::mkUInt(1, var_vector_elements); } + // return a constant value of the given type static expr some(const expr &type); @@ -160,6 +166,7 @@ class expr { // best effort; returns number of statically known bits unsigned min_leading_zeros() const; unsigned min_trailing_ones() const; + unsigned active_bits() const; expr operator+(const expr &rhs) const; expr operator-(const expr &rhs) const; diff --git a/tests/alive-tv/vector/vscale/dse-scalable-fixed-neg.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-fixed-neg.srctgt.ll new file mode 100644 index 000000000..9a29b9f35 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-fixed-neg.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 2) { + %gep.ptr.16 = getelementptr i64, ptr %ptr, i64 16 + store <2 x i64> zeroinitializer, ptr %gep.ptr.16 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 2) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/dse-scalable-fixed.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-fixed.srctgt.ll new file mode 100644 index 000000000..c405dea46 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-fixed.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 2) { + %gep.ptr.2 = getelementptr i64, ptr %ptr, i64 2 + store <2 x i64> zeroinitializer, ptr %gep.ptr.2 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 2) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/dse-scalable-scalable-neg.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-scalable-neg.srctgt.ll new file mode 100644 index 000000000..be4b207d0 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-scalable-neg.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 4) { + %gep.ptr.8 = getelementptr i64, ptr %ptr, i64 8 + store zeroinitializer, ptr %gep.ptr.8 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 4) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/dse-scalable-scalable.srctgt.ll b/tests/alive-tv/vector/vscale/dse-scalable-scalable.srctgt.ll new file mode 100644 index 000000000..95ad3d6b7 --- /dev/null +++ b/tests/alive-tv/vector/vscale/dse-scalable-scalable.srctgt.ll @@ -0,0 +1,15 @@ +; SKIP-IDENTITY + +define void @src(ptr %ptr) vscale_range(1, 4) { + %gep.ptr.2 = getelementptr i64, ptr %ptr, i64 2 + store zeroinitializer, ptr %gep.ptr.2 + store zeroinitializer, ptr %ptr + ret void +} + +define void @tgt(ptr %ptr) vscale_range(1, 4) { + store zeroinitializer, ptr %ptr + ret void +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/inbounds-poison.srctgt.ll b/tests/alive-tv/vector/vscale/inbounds-poison.srctgt.ll new file mode 100644 index 000000000..35357765b --- /dev/null +++ b/tests/alive-tv/vector/vscale/inbounds-poison.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) vscale_range(2, 4) { + %v = insertelement %a, i8 -1, i64 2 + ret %v +} + +define @tgt( %a) vscale_range(2, 4) { + ret poison +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/insert-extract-constvscale.srctgt.ll b/tests/alive-tv/vector/vscale/insert-extract-constvscale.srctgt.ll new file mode 100644 index 000000000..761493d59 --- /dev/null +++ b/tests/alive-tv/vector/vscale/insert-extract-constvscale.srctgt.ll @@ -0,0 +1,9 @@ +define i8 @src( %a) vscale_range(4, 4) { + %v = insertelement %a, i8 -1, i64 2 + %r = extractelement %v, i64 2 + ret i8 %r +} + +define i8 @tgt( %a) vscale_range(4, 4) { + ret i8 -1 +} diff --git a/tests/alive-tv/vector/vscale/insert-extract.srctgt.ll b/tests/alive-tv/vector/vscale/insert-extract.srctgt.ll new file mode 100644 index 000000000..4b9301679 --- /dev/null +++ b/tests/alive-tv/vector/vscale/insert-extract.srctgt.ll @@ -0,0 +1,13 @@ +; SKIP-IDENTITY + +define i8 @src( %a) vscale_range(2, 4) { + %v = insertelement %a, i8 -1, i64 2 + %r = extractelement %v, i64 2 + ret i8 %r +} + +define i8 @tgt( %a) vscale_range(2, 4) { + ret i8 -1 +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/out-of-bounds-poison.srctgt.ll b/tests/alive-tv/vector/vscale/out-of-bounds-poison.srctgt.ll new file mode 100644 index 000000000..7cf8207a0 --- /dev/null +++ b/tests/alive-tv/vector/vscale/out-of-bounds-poison.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) vscale_range(1, 2) { + %v = insertelement %a, i8 -2, i64 3 + ret %v +} + +define @tgt( %a) vscale_range(1, 2) { + ret poison +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll b/tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll new file mode 100644 index 000000000..af5547da4 --- /dev/null +++ b/tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll @@ -0,0 +1,14 @@ +define i32 @src(i32 %a) vscale_range(4, 4) { + %poison = add nsw i32 2147483647, 100 + %v = insertelement poison, i32 %a, i64 0 + %v2 = insertelement %v, i32 %poison, i64 1 + %w = extractelement %v2, i64 0 + ret i32 %w +} + +define i32 @tgt(i32 %a) vscale_range(4, 4) { + %poison = add nsw i32 2147483647, 100 + ret i32 %poison +} + +; ERROR: Target is more poisonous than source diff --git a/tests/alive-tv/vector/vscale/rem-constvscale.srctgt.ll b/tests/alive-tv/vector/vscale/rem-constvscale.srctgt.ll new file mode 100644 index 000000000..60fe126da --- /dev/null +++ b/tests/alive-tv/vector/vscale/rem-constvscale.srctgt.ll @@ -0,0 +1,16 @@ +; SKIP-IDENTITY + +define @src( %x) vscale_range(2, 2) { + %rem.i = srem %x, splat(i8 2) + %cmp.i = icmp slt %rem.i, zeroinitializer + %add.i = select %cmp.i, splat(i8 2), zeroinitializer + ret %add.i +} + +define @tgt( %x) vscale_range(2, 2) { + %rem.i = srem %x, splat(i8 2) + %tmp1 = and %rem.i, splat(i8 2) + ret %tmp1 +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/rem.srctgt.ll b/tests/alive-tv/vector/vscale/rem.srctgt.ll new file mode 100644 index 000000000..c481cbe3f --- /dev/null +++ b/tests/alive-tv/vector/vscale/rem.srctgt.ll @@ -0,0 +1,16 @@ +; SKIP-IDENTITY + +define @src( %x) vscale_range(1, 2) { + %rem.i = srem %x, splat(i8 2) + %cmp.i = icmp slt %rem.i, zeroinitializer + %add.i = select %cmp.i, splat(i8 2), zeroinitializer + ret %add.i +} + +define @tgt( %x) vscale_range(1, 2) { + %rem.i = srem %x, splat(i8 2) + %tmp1 = and %rem.i, splat(i8 2) + ret %tmp1 +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/typecheck-missing-vscale-range.srctgt.ll b/tests/alive-tv/vector/vscale/typecheck-missing-vscale-range.srctgt.ll new file mode 100644 index 000000000..3c70ae8dd --- /dev/null +++ b/tests/alive-tv/vector/vscale/typecheck-missing-vscale-range.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) { + %v = insertelement %a, i8 -2, i64 3 + ret %v +} + +define @tgt( %a) { + ret poison +} + +; ERROR: program doesn't type check! diff --git a/tests/alive-tv/vector/vscale/typecheck-scalable-non-scalable.srctgt.ll b/tests/alive-tv/vector/vscale/typecheck-scalable-non-scalable.srctgt.ll new file mode 100644 index 000000000..e700e7fdd --- /dev/null +++ b/tests/alive-tv/vector/vscale/typecheck-scalable-non-scalable.srctgt.ll @@ -0,0 +1,12 @@ +; SKIP-IDENTITY + +define @src( %a) vscale_range(1, 2) { + %v = insertelement %a, i8 -2, i64 3 + ret %v +} + +define <1 x i8> @tgt(<1 x i8> %a) vscale_range(1, 2) { + ret <1 x i8> poison +} + +; ERROR: program doesn't type check! diff --git a/tools/transform.cpp b/tools/transform.cpp index ad45c6511..e225119f8 100644 --- a/tools/transform.cpp +++ b/tools/transform.cpp @@ -97,11 +97,11 @@ void tools::print_model_val(ostream &os, State &st, const Model &m, os << (type.isStructType() ? "{ " : "< "); auto agg = type.getAsAggregateType(); - for (unsigned i = 0, e = agg->numElementsConst(); i < e; ++i) { + for (unsigned i = 0, e = agg->numElementsConst(st.getVscale()); i < e; ++i) { if (i != 0) os << ", "; tools::print_model_val(os, st, m, var, agg->getChild(i), - agg->extract(val, i), child + i); + agg->extract(val, i, st.getVscale()), child + i); } os << (type.isStructType() ? " }" : " >"); } @@ -566,7 +566,7 @@ check_refinement(Errors &errs, const Transform &t, State &src_state, errs.add("Precondition is always false", false); return; } - + vector> repls; auto vars_pre = pre_src.vars(); for (auto &v : qvars) { @@ -1497,7 +1497,7 @@ TypingAssignments TransformVerify::getTypings() const { auto c = t.src.getTypeConstraints() && t.tgt.getTypeConstraints(); if (t.precondition) - c &= t.precondition->getTypeConstraints(); + c &= t.precondition->getTypeConstraints(t.src); // return type c &= t.src.getType() == t.tgt.getType();