Skip to content

Commit

Permalink
fix #937: infer alignment of some memory ops during preprocessing
Browse files Browse the repository at this point in the history
This allows more efficient memory encodings
  • Loading branch information
nunoplopes committed Sep 24, 2023
1 parent a6205b6 commit 612737a
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 1 deletion.
29 changes: 28 additions & 1 deletion ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2093,7 +2093,8 @@ uint64_t FnCall::getAlign() const {
align = getIntOr(*arg, 1);

return max(align,
attrs.has(FnAttrs::Align) ? attrs.align : heap_block_alignment);
attrs.has(FnAttrs::Align) ? attrs.align :
(attrs.isAlloc() ? heap_block_alignment : 1));
}

uint64_t FnCall::getMaxAccessSize() const {
Expand Down Expand Up @@ -3618,6 +3619,23 @@ uint64_t GEP::getMaxGEPOffset() const {
return off;
}

optional<uint64_t> GEP::getExactOffset() const {
uint64_t off = 0;
for (auto &[mul, v] : getIdxs()) {
if (mul == 0)
continue;
if (mul >= INT64_MAX)
return {};

if (auto n = getInt(*v)) {
off = add_saturate(off, abs((int64_t)mul * *n));
continue;
}
return {};
}
return off;
}

vector<Value*> GEP::operands() const {
vector<Value*> v = { ptr };
for (auto &[sz, idx] : idxs) {
Expand Down Expand Up @@ -3763,6 +3781,15 @@ uint64_t PtrMask::getMaxGEPOffset() const {
return UINT64_MAX;
}

optional<uint64_t> PtrMask::getExactAlign() const {
if (auto n = getInt(*mask)) {
uint64_t align = -*n;
if (is_power2(align))
return align;
}
return {};
}

vector<Value*> PtrMask::operands() const {
return { ptr, mask };
}
Expand Down
13 changes: 13 additions & 0 deletions ir/instr.h
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ class Alloc final : public MemInstr {

Value& getSize() const { return *size; }
Value* getMul() const { return mul; }
uint64_t getAlign() const { return align; }
bool initDead() const { return initially_dead; }
void markAsInitiallyDead() { initially_dead = true; }

Expand Down Expand Up @@ -766,6 +767,7 @@ class GEP final : public MemInstr {
Value& getPtr() const { return *ptr; }
auto& getIdxs() const { return idxs; }
bool isInBounds() const { return inbounds; }
std::optional<uint64_t> getExactOffset() const;

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand All @@ -791,6 +793,9 @@ class PtrMask final : public MemInstr {
PtrMask(Type &type, std::string &&name, Value &ptr, Value &mask)
: MemInstr(type, std::move(name)), ptr(&ptr), mask(&mask) {}

Value& getPtr() const { return *ptr; }
std::optional<uint64_t> getExactAlign() const;

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
uint64_t getMaxGEPOffset() const override;
Expand All @@ -817,6 +822,7 @@ class Load final : public MemInstr {

Value& getPtr() const { return *ptr; }
uint64_t getAlign() const { return align; }
void setAlign(uint64_t align) { this->align = align; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand Down Expand Up @@ -844,6 +850,7 @@ class Store final : public MemInstr {
Value& getValue() const { return *val; }
Value& getPtr() const { return *ptr; }
uint64_t getAlign() const { return align; }
void setAlign(uint64_t align) { this->align = align; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand All @@ -870,8 +877,10 @@ class Memset final : public MemInstr {
: MemInstr(Type::voidTy, "memset"), ptr(&ptr), val(&val), bytes(&bytes),
align(align) {}

Value& getPtr() const { return *ptr; }
Value& getBytes() const { return *bytes; }
uint64_t getAlign() const { return align; }
void setAlign(uint64_t align) { this->align = align; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand Down Expand Up @@ -943,9 +952,13 @@ class Memcpy final : public MemInstr {
: MemInstr(Type::voidTy, "memcpy"), dst(&dst), src(&src), bytes(&bytes),
align_dst(align_dst), align_src(align_src), move(move) {}

Value& getSrc() const { return *src; }
Value& getDst() const { return *dst; }
Value& getBytes() const { return *bytes; }
uint64_t getSrcAlign() const { return align_src; }
uint64_t getDstAlign() const { return align_dst; }
void setSrcAlign(uint64_t align) { align_src = align; }
void setDstAlign(uint64_t align) { align_dst = align; }

std::pair<uint64_t, uint64_t> getMaxAllocSize() const override;
uint64_t getMaxAccessSize() const override;
Expand Down
21 changes: 21 additions & 0 deletions tests/alive-tv/opt-memory/align-infer.srctgt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; TEST-ARGS: -dbg

define void @src() {
%p = alloca i64, align 4
store i32 0, ptr %p, align 1
%p2 = call ptr @llvm.ptrmask(ptr %p, i64 -8)
store i32 0, ptr %p2, align 4
ret void
}

define void @tgt() {
%p = alloca i64, align 4
store i32 0, ptr %p, align 4
%p2 = call ptr @llvm.ptrmask(ptr %p, i64 -8)
store i32 0, ptr %p2, align 8
ret void
}

declare ptr @llvm.ptrmask(ptr, i64)

; CHECK: bits_byte: 32
37 changes: 37 additions & 0 deletions tests/alive-tv/opt-memory/align-phi-infer.ident.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; TEST-ARGS: -dbg

define void @src() {
entry:
%p = alloca i64, align 4
br label %loop

loop:
%p2 = phi ptr [ %p, %entry ], [ %p3, %loop ]
store i16 0, ptr %p2, align 1
%p3 = getelementptr i8, ptr %p2, i64 2
%cond = call i1 @cond()
br i1 %cond, label %loop, label %exit

exit:
ret void
}

define void @tgt() {
entry:
%p = alloca i64, align 4
br label %loop

loop:
%p2 = phi ptr [ %p, %entry ], [ %p3, %loop ]
store i16 0, ptr %p2, align 2
%p3 = getelementptr i8, ptr %p2, i64 2
%cond = call i1 @cond()
br i1 %cond, label %loop, label %exit

exit:
ret void
}

declare i1 @cond()

; CHECK: bits_byte: 16
74 changes: 74 additions & 0 deletions tools/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,80 @@ void Transform::preprocess() {
} while (changed);
}

// infer alignment of memory operations
unordered_map<const Value*, uint64_t> aligns;
unordered_set<const Value*> worklist;
for (auto fn : { &src, &tgt }) {
for (auto &in0 : fn->getInputs()) {
auto *in = dynamic_cast<const Input*>(&in0);
if (!in || !in->getType().isPtrType())
continue;
aligns.emplace(in, in->getAttributes().align);
}

for (auto &i : fn->instrs()) {
worklist.emplace(&i);
}

auto users = fn->getUsers();

do {
auto I = worklist.begin();
auto *i = *I;
worklist.erase(I);

uint64_t align = 0;
if (auto *alloc = dynamic_cast<const Alloc*>(i)) {
align = alloc->getAlign();
} else if (auto *call = dynamic_cast<const FnCall*>(i)) {
align = call->getAlign();
} else if (auto *mask = dynamic_cast<const PtrMask*>(i)) {
if (auto a = mask->getExactAlign()) {
align = max(*a, aligns[&mask->getPtr()]);
} else {
continue;
}
} else if (auto *gep = dynamic_cast<const GEP*>(i)) {
auto off = gep->getExactOffset();
if (!off || !is_power2(*off))
continue;
align = min(aligns[&gep->getPtr()], *off);
} else if (auto *phi = dynamic_cast<const Phi*>(i)) {
// optimistic handling of phis: unreachable predecessors don't
// contribute to the result. This is revisited once they become reach
for (auto &[val, bb] : phi->getValues()) {
if (auto phi_align = aligns[val])
align = align ? min(align, phi_align) : phi_align;
}
} else {
continue;
}
if (align != aligns[i]) {
aligns[i] = align;
for (auto [user, BB] : users[i]) {
worklist.emplace(user);
}
}
} while (!worklist.empty());

#define HANDLE(class, ptr, get, set) \
if (auto *obj = dynamic_cast<class*>(i)) { \
uint64_t newal = aligns[&obj->ptr()]; \
if (newal > obj->get()) \
obj->set(newal); \
}

for (auto &i0 : fn->instrs()) {
auto i = const_cast<Instr*>(&i0);
HANDLE(Store, getPtr, getAlign, setAlign)
HANDLE(Load, getPtr, getAlign, setAlign)
HANDLE(Memset, getPtr, getAlign, setAlign)
HANDLE(Memcpy, getSrc, getSrcAlign, setSrcAlign)
HANDLE(Memcpy, getDst, getDstAlign, setDstAlign)
}
aligns.clear();
}

// bits_program_pointer is used by unroll. Initialize it in advance
initBitsProgramPointer(*this);

Expand Down

0 comments on commit 612737a

Please sign in to comment.