Skip to content

Commit

Permalink
type: add preliminiary support for singular vscale
Browse files Browse the repository at this point in the history
Change llvm_util to correctly fill aggregates and scalable vector
splats, based on the vscale_range attribute on the function, and add a
virtual numElementsConst() override, to operate on the maximum value in
the vscale range. Not considering all possible vscale values is
obviously a correctness issue, so ensure that we only pass the
type-check when the vscale range has a singular value.

In practice, the entire codebase needs to be taught about scalable
vectors, before all programs with a singular vscale value are supported.
The rem-costvscale.srctgt.ll test fails to type-check, for instance, due
to ShuffleVector using a non-scalable mask. These instances are left as
a todo for follow-up patches.
  • Loading branch information
artagnon committed Dec 9, 2024
1 parent b2f26b6 commit 663e155
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 16 deletions.
15 changes: 11 additions & 4 deletions ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1098,8 +1098,15 @@ VectorType::VectorType(string &&name, unsigned minElems, Type &elementTy,
this->isScalableTy = isScalableTy;
this->elements = minElems;
defined = true;
children.resize(elements, &elementTy);
is_padding.resize(elements, false);
unsigned scaleFactor = isScalableTy ? var_vector_max_vscale : 1;
children.resize(elements * scaleFactor, &elementTy);

Check failure

Code scanning / CodeQL

Multiplication result converted to larger type High

Multiplication result may overflow 'unsigned int' before it is converted to 'size_type'.
is_padding.resize(elements * scaleFactor, false);

Check failure

Code scanning / CodeQL

Multiplication result converted to larger type High

Multiplication result may overflow 'unsigned int' before it is converted to 'size_type'.
}

unsigned VectorType::numElementsConst(expr vscaleRange) const {
unsigned scaleFactor =
isScalable() ? 1 << (vscaleRange.active_bits() - 1) : 1;
return elements * scaleFactor;
}

StateValue VectorType::extract(const StateValue &vector,
Expand Down Expand Up @@ -1176,8 +1183,8 @@ expr VectorType::getTypeConstraints(const Function &f) const {
r &= elems.ugt(i).implies(elementTy == *children[i]);
}

// TODO: remove once scalable vectors are supported.
r &= !isScalable();
// TODO: remove once scalable vectors are fully supported.
r &= vscaleRange.isPowerOf2();

return r;
}
Expand Down
3 changes: 2 additions & 1 deletion ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class AggregateType : public Type {
smt::expr numElements(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const;
smt::expr numElementsExcludingPadding(
smt::expr vscaleRange = smt::expr::mkVscaleMin()) const;
unsigned
virtual unsigned
numElementsConst(smt::expr vscaleRange = smt::expr::mkVscaleMin()) const {
return elements;
}
Expand Down Expand Up @@ -360,6 +360,7 @@ class VectorType final : public AggregateType {
VectorType(std::string &&name, unsigned minElems, Type &elementTy,
bool isScalableTy = false);

virtual unsigned numElementsConst(smt::expr vscaleRange) const override;
IR::StateValue extract(const IR::StateValue &vector,
const smt::expr &index,
smt::expr vscaleRange) const;
Expand Down
12 changes: 8 additions & 4 deletions llvm_util/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,11 @@ Value* get_operand(llvm::Value *v,
if (!ty)
return nullptr;

smt::expr vscaleRange =
State::vscaleFromAttr(current_fn->getFnAttrs().vscaleRange);

// automatic splat of constant values
if (auto vty = dyn_cast<llvm::FixedVectorType>(v->getType());
if (auto vty = dyn_cast<llvm::VectorType>(v->getType());
vty && isa<llvm::ConstantInt, llvm::ConstantFP>(v)) {
llvm::Value *llvm_splat = nullptr;
if (auto cnst = dyn_cast<llvm::ConstantInt>(v)) {
Expand All @@ -317,8 +320,9 @@ Value* get_operand(llvm::Value *v,
if (!splat)
return nullptr;

vector<Value*> vals(vty->getNumElements(), splat);
auto val = make_unique<AggregateValue>(*ty, std::move(vals));
unsigned ec = ty->getAsAggregateType()->numElementsConst(vscaleRange);
vector<Value *> vals(ec, splat);
auto val = make_unique<AggregateValue>(*ty, std::move(vals), vscaleRange);
auto ret = val.get();
current_fn->addConstant(std::move(val));
RETURN_CACHE(ret);
Expand Down Expand Up @@ -411,7 +415,7 @@ Value* get_operand(llvm::Value *v,
{
unsigned opi = 0;

for (unsigned i = 0; i < aty->numElementsConst(); ++i) {
for (unsigned i = 0; i < aty->numElementsConst(vscaleRange); ++i) {
if (!aty->isPadding(i)) {
if (auto op = get_operand(get_elem(opi), constexpr_conv, copy_inserter,
register_fn_decl))
Expand Down
4 changes: 4 additions & 0 deletions smt/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,10 @@ unsigned expr::min_trailing_ones() const {
return 0;
}

unsigned expr::active_bits() const {
return bits() - min_leading_zeros();
}

expr expr::binop_commutative(const expr &rhs,
Z3_ast (*op)(Z3_context, Z3_ast, Z3_ast),
expr (expr::*expr_op)(const expr &) const,
Expand Down
1 change: 1 addition & 0 deletions smt/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class expr {
// best effort; returns number of statically known bits
unsigned min_leading_zeros() const;
unsigned min_trailing_ones() const;
unsigned active_bits() const;

expr operator+(const expr &rhs) const;
expr operator-(const expr &rhs) const;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
; SKIP-IDENTITY

define i8 @src(<vscale x 1 x i8> %a) vscale_range(4, 4) {
%v = insertelement <vscale x 1 x i8> %a, i8 -1, i64 2
%r = extractelement <vscale x 1 x i8> %v, i64 2
Expand All @@ -9,5 +7,3 @@ define i8 @src(<vscale x 1 x i8> %a) vscale_range(4, 4) {
define i8 @tgt(<vscale x 1 x i8> %a) vscale_range(4, 4) {
ret i8 -1
}

; ERROR: program doesn't type check!
4 changes: 1 addition & 3 deletions tests/alive-tv/vector/vscale/poison-constvscale.srctgt.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
; SKIP-IDENTITY

define i32 @src(i32 %a) vscale_range(4, 4) {
%poison = add nsw i32 2147483647, 100
%v = insertelement <vscale x 2 x i32> poison, i32 %a, i64 0
Expand All @@ -13,4 +11,4 @@ define i32 @tgt(i32 %a) vscale_range(4, 4) {
ret i32 %poison
}

; ERROR: program doesn't type check!
; ERROR: Target is more poisonous than source

0 comments on commit 663e155

Please sign in to comment.