Skip to content

Commit

Permalink
Merge pull request #91 from dop-amin/flexible_loops
Browse files Browse the repository at this point in the history
More general structure for loop parsing
  • Loading branch information
hanno-becker authored Nov 29, 2024
2 parents 63ce18e + 899328b commit 90a78c0
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 150 deletions.
35 changes: 35 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,37 @@ def core(self, slothy):
slothy.config.inputs_are_outputs = True
slothy.optimize_loop("start")

class LoopLe(Example):
    """Example driving the optimizer over the ``loop_le`` input.

    The input name is ``loop_le`` with an optional ``_<var>`` suffix; the
    example name additionally carries the label that ``target_label_dict``
    associates with ``target``.
    """

    def __init__(self, var="", arch=Arch_Armv81M, target=Target_CortexM55r1):
        # An empty `var` selects the base variant and adds no suffix.
        suffix = f"_{var}" if var != "" else ""
        infile = "loop_le" + suffix
        name = f"{infile}_{target_label_dict[target]}"

        super().__init__(infile, name, rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Enable `variable_size` and optimize the loop labelled ``start``."""
        slothy.config.variable_size = True
        slothy.optimize_loop("start")

class AArch64LoopSubs(Example):
    """Example driving the optimizer over the ``aarch64_loop_subs`` input.

    The input name is ``aarch64_loop_subs`` with an optional ``_<var>``
    suffix; the example name additionally carries the label that
    ``target_label_dict`` associates with ``target``.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
        # An empty `var` selects the base variant and adds no suffix.
        suffix = f"_{var}" if var != "" else ""
        infile = "aarch64_loop_subs" + suffix
        name = f"{infile}_{target_label_dict[target]}"

        super().__init__(infile, name, rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Enable `variable_size` and optimize the loop labelled ``start``."""
        slothy.config.variable_size = True
        slothy.optimize_loop("start")

class CRT(Example):
def __init__(self):
Expand Down Expand Up @@ -1394,6 +1425,10 @@ def main():
AArch64Example2(),
AArch64Example2(target=Target_CortexA72),

# Loop examples
AArch64LoopSubs(),
LoopLe(),

CRT(),

ntt_n256_l6_s32("bar"),
Expand Down
9 changes: 9 additions & 0 deletions examples/naive/aarch64/aarch64_loop_subs.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
count .req x2

mov count, #16
start:

nop

subs count, count, #1
cbnz count, start
4 changes: 4 additions & 0 deletions examples/naive/loop_le.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mov lr, #16
start:
nop
le lr, start
20 changes: 12 additions & 8 deletions slothy/core/slothy.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,11 @@ def fusion_loop(self, loop_lbl):
"""Run fusion callbacks on loop body"""
logger = self.logger.getChild(f"ssa_loop_{loop_lbl}")

pre , body, post, _, other_data = \
pre , body, post, _, other_data, loop = \
self.arch.Loop.extract(self.source, loop_lbl)
(loop_cnt, _, _) = other_data
loop_cnt = other_data['cnt']
indentation = AsmHelper.find_indentation(body)

loop = self.arch.Loop(lbl_start=loop_lbl)
body_ssa = SourceLine.read_multiline(loop.start(loop_cnt)) + \
SourceLine.apply_indentation(self._fusion_core(pre, body, logger), indentation) + \
SourceLine.read_multiline(loop.end(other_data))
Expand All @@ -394,13 +393,15 @@ def fusion_loop(self, loop_lbl):
assert SourceLine.is_source(self.source)

def optimize_loop(self, loop_lbl, postamble_label=None):
"""Optimize the loop starting at a given label"""
"""Optimize the loop starting at a given label
The postamble_label marks the end of the loop kernel.
"""

logger = self.logger.getChild(loop_lbl)

early, body, late, _, other_data = \
early, body, late, _, other_data, loop = \
self.arch.Loop.extract(self.source, loop_lbl)
(loop_cnt, _, _) = other_data
loop_cnt = other_data['cnt']

# Check if the body has a dominant indentation
indentation = AsmHelper.find_indentation(body)
Expand Down Expand Up @@ -464,7 +465,6 @@ def loop_lbl_iter(i):
for i in range(1, num_exceptional):
optimized_code += indented(self.arch.Branch.if_equal(loop_cnt, i, loop_lbl_iter(i)))

loop = self.arch.Loop(lbl_start=loop_lbl)
optimized_code += indented(preamble_code)

if self.config.sw_pipelining.unknown_iteration_count:
Expand All @@ -479,7 +479,11 @@ def loop_lbl_iter(i):
indentation=self.config.indentation,
fixup=num_exceptional,
unroll=self.config.sw_pipelining.unroll,
jump_if_empty=jump_if_empty))
jump_if_empty=jump_if_empty,
preamble_code=preamble_code,
body_code=kernel_code,
postamble_code=postamble_code,
register_aliases=c.register_aliases))
optimized_code += indented(kernel_code)
optimized_code += SourceLine.read_multiline(loop.end(other_data,
indentation=self.config.indentation))
Expand Down
106 changes: 106 additions & 0 deletions slothy/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import re
import subprocess
import logging
from abc import ABC, abstractmethod

from slothy.targets.common import *

class SourceLine:
"""Representation of a single line of source code"""
Expand Down Expand Up @@ -1089,3 +1092,106 @@ def forward_to_file(self, log_label, filename, lvl=logging.DEBUG):
h.setLevel(lvl)
l.addHandler(h)
self.forward(l)


class Loop(ABC):
    """Abstract description of one textual loop shape.

    A concrete loop type knows how to recognise its loop in a list of source
    lines (`_extract`) and how to emit the instructions opening (`start`) and
    closing (`end`) such a loop.

    Besides implementing `start` and `end`, a subclass must set two instance
    attributes that `_extract` reads:

    * ``lbl_regex`` -- pattern with the named groups ``label`` and
      ``remainder``, matching a line that carries a label.
    * ``end_regex`` -- sequence of patterns, one per instruction of the loop
      end, in order. Their named groups are collected in
      ``additional_data``; the values are the matched *strings*.
    """

    def __init__(self, lbl_start="1", lbl_end="2", loop_init="lr"):
        self.lbl_start = lbl_start
        self.lbl_end = lbl_end
        self.loop_init = loop_init
        # Named groups captured from the loop-end patterns by _extract().
        self.additional_data = {}

    @abstractmethod
    def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None):
        """Emit starting instruction(s) and jump label for loop"""

    @abstractmethod
    def end(self, other, indentation=0):
        """Emit compare-and-branch at the end of the loop"""

    def _extract(self, source, lbl):
        """Locate a loop with start label `lbl` in `source`.

        Args:
            source: list of line objects exposing ``text``, ``copy()`` and
                ``set_text()`` (plain strings are rejected).
            lbl: label that opens the loop.

        Returns:
            Tuple ``(pre, body, post, lbl, additional_data)``: the lines
            before the loop, the loop body (starting with whatever followed
            the label on its line), the lines after the loop end, the label,
            and the named groups captured from the loop-end patterns.

        Raises:
            FatalParsingException: no label `lbl` followed by a complete
                loop end was found.
        """
        assert isinstance(source, list)

        pre = []
        body = []
        post = []
        # Lines that matched a prefix of the loop-end sequence so far.
        loop_end_candidates = []
        loop_lbl_regexp = re.compile(self.lbl_regex)

        # end_regex shall contain group cnt as the counter variable
        loop_end_regexp = [re.compile(txt) for txt in self.end_regex]
        lines = iter(source)
        l = None
        keep = False  # when set, the current line is processed once more
        state = 0  # 0: haven't found loop yet, 1: extracting loop, 2: after loop
        loop_end_ctr = 0  # index of the loop-end pattern expected next
        while True:
            if not keep:
                l = next(lines, None)
            keep = False
            if l is None:
                break
            assert isinstance(l, str) is False
            l_str = l.text
            if state == 0:
                p = loop_lbl_regexp.match(l_str)
                if p is not None and p.group("label") == lbl:
                    # Strip the label and treat the rest of the line as the
                    # first line of the loop body.
                    l = l.copy().set_text(p.group("remainder"))
                    keep = True
                    state = 1
                else:
                    pre.append(l)
                continue
            if state == 1:
                p = loop_end_regexp[loop_end_ctr].match(l_str)
                if p is not None:
                    # Case: We may have encountered part of the loop end;
                    # collect all named groups.
                    self.additional_data = self.additional_data | p.groupdict()
                    loop_end_ctr += 1
                    loop_end_candidates.append(l)
                    if loop_end_ctr == len(loop_end_regexp):
                        state = 2
                    continue
                if loop_end_ctr > 0 and l_str != "":
                    # Case: The sequence of loop end candidates was
                    # interrupted, i.e. we found a false positive. The
                    # candidates therefore belong to the body.
                    body += loop_end_candidates
                    self.additional_data = {}
                    loop_end_ctr = 0
                    loop_end_candidates = []
                    # The interrupting line may itself open the real loop
                    # end, so test it again against the first pattern
                    # instead of committing it to the body unseen.
                    keep = True
                    continue
                body.append(l)
                continue
            if state == 2:
                post.append(l)
                continue
        if state < 2:
            raise FatalParsingException(f"Couldn't identify loop {lbl}")
        return pre, body, post, lbl, self.additional_data

    @staticmethod
    def extract(source, lbl):
        """Extract the loop labelled `lbl` using the first matching loop type.

        Every direct subclass of `Loop` is instantiated with `lbl` as its
        only positional argument and tried in turn.

        Returns:
            The tuple returned by `_extract`, extended by the instance of the
            loop type that recognised the loop.

        Raises:
            FatalParsingException: no loop type recognised the loop.
        """
        for loop_type in Loop.__subclasses__():
            try:
                loop = loop_type(lbl)
                # (loop,) is a one-element tuple appended to the tuple
                # returned by _extract.
                return loop._extract(source, lbl) + (loop,)
            except FatalParsingException:
                logging.debug("Parsing loop type '%s' failed", loop_type.__name__)

        raise FatalParsingException(f"Couldn't identify loop {lbl}")
121 changes: 29 additions & 92 deletions slothy/targets/aarch64/aarch64_neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class which generates instruction parsers and writers from instruction templates

from sympy import simplify

from slothy.targets.common import *
from slothy.helper import Loop

arch_name = "Arm_AArch64"
llvm_mca_arch = "aarch64"

Expand Down Expand Up @@ -169,119 +172,53 @@ def unconditional(lbl):
"""Emit unconditional branch"""
yield f"b {lbl}"

class Loop:
"""Helper functions for parsing and writing simple loops in AArch64

TODO: Generalize; current implementation too specific about shape of loop"""

def __init__(self, lbl_start="1", lbl_end="2", loop_init="lr"):
self.lbl_start = lbl_start
self.lbl_end = lbl_end
self.loop_init = loop_init
class SubsLoop(Loop):
"""
Loop ending in a flag setting subtraction and a branch.

def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None):
Example:
```
loop_lbl:
{code}
sub[s] <cnt>, <cnt>, #<imm>
(cbnz|bnz|bne) <cnt>, loop_lbl
```
where cnt is the loop counter in lr.
"""
def __init__(self, lbl="lbl", lbl_start="1", lbl_end="2", loop_init="lr") -> None:
    """Set up the patterns that recognise a loop whose start label is `lbl`."""
    # Keep group names consistent across patterns: the same register gets
    # the same group name wherever it appears.
    label_pattern = r"^\s*(?P<label>\w+)\s*:(?P<remainder>.*)$"
    decrement_pattern = r"^\s*sub[s]?\s+(?P<cnt>\w+),\s*(?P<reg1>\w+),\s*#(?P<imm>\d+)"
    branch_pattern = rf"^\s*(cbnz|bnz|bne)\s+(?P<cnt>\w+),\s*{lbl}"

    super().__init__(lbl_start=lbl_start, lbl_end=lbl_end, loop_init=loop_init)
    self.lbl_regex = label_pattern
    # The loop end is the decrement immediately followed by the branch.
    self.end_regex = (decrement_pattern, branch_pattern)

def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None, preamble_code=None, body_code=None, postamble_code=None, register_aliases=None):
    """Emit starting instruction(s) and jump label for loop

    Args:
        loop_cnt: name of the register holding the loop counter.
        indentation: number of spaces put in front of the emitted
            ``lsr``/``sub`` instructions.
        fixup: number of iterations to take off the counter before the
            loop is entered.
        unroll: unrolling factor; the counter is shifted right by
            ``log2(unroll)``. Must be one of 1, 2, 4, 8, 16, 32.
        jump_if_empty: if given, label to branch to when the adjusted
            counter is zero.
        preamble_code, body_code, postamble_code, register_aliases:
            accepted for interface compatibility; not used by this loop
            type.

    Yields:
        The lines opening the loop, ending with the start label.
    """
    indent = ' ' * indentation
    if unroll > 1:
        assert unroll in [1, 2, 4, 8, 16, 32]
        yield f"{indent}lsr {loop_cnt}, {loop_cnt}, #{int(math.log2(unroll))}"
    if fixup != 0:
        # The counter is decremented by `imm` per iteration, so the fixup
        # (counted in iterations) has to be scaled by it. `imm` is a regex
        # capture and hence a string: convert it before multiplying, since
        # int * str would repeat the string instead. Without a captured
        # immediate a step of 1 is assumed.
        step = int(self.additional_data.get('imm', 1))
        yield f"{indent}sub {loop_cnt}, {loop_cnt}, #{fixup * step}"
    if jump_if_empty is not None:
        yield f"cbz {loop_cnt}, {jump_if_empty}"
    yield f"{self.lbl_start}:"

def end(self, other, indentation=0):
"""Emit compare-and-branch at the end of the loop"""
(reg0, reg1, imm) = other
indent = ' ' * indentation
lbl_start = self.lbl_start
if lbl_start.isdigit():
lbl_start += "b"

yield f"{indent}sub {reg0}, {reg1}, {imm}"
yield f"{indent}cbnz {reg0}, {lbl_start}"

@staticmethod
def extract(source, lbl):
"""Locate a loop with start label `lbl` in `source`.

We currently only support the following loop forms:
yield f"{indent}sub {other['cnt']}, {other['cnt']}, {other['imm']}"
yield f"{indent}cbnz {other['cnt']}, {lbl_start}"

```
loop_lbl:
{code}
sub[s] <cnt>, <cnt>, #1
(cbnz|bnz|bne) <cnt>, loop_lbl
```

"""
assert isinstance(source, list)

pre = []
body = []
post = []
loop_lbl_regexp_txt = r"^\s*(?P<label>\w+)\s*:(?P<remainder>.*)$"
loop_lbl_regexp = re.compile(loop_lbl_regexp_txt)

# TODO: Allow other forms of looping

loop_end_regexp_txt = (r"^\s*sub[s]?\s+(?P<reg0>\w+),\s*(?P<reg1>\w+),\s*(?P<imm>#1)",
rf"^\s*(cbnz|bnz|bne)\s+(?P<reg0>\w+),\s*{lbl}")
loop_end_regexp = [re.compile(txt) for txt in loop_end_regexp_txt]
lines = iter(source)
l = None
keep = False
state = 0 # 0: haven't found loop yet, 1: extracting loop, 2: after loop
while True:
if not keep:
l = next(lines, None)
keep = False
if l is None:
break
l_str = l.text
assert isinstance(l, str) is False
if state == 0:
p = loop_lbl_regexp.match(l_str)
if p is not None and p.group("label") == lbl:
l = l.copy().set_text(p.group("remainder"))
keep = True
state = 1
else:
pre.append(l)
continue
if state == 1:
p = loop_end_regexp[0].match(l_str)
if p is not None:
reg0 = p.group("reg0")
reg1 = p.group("reg1")
imm = p.group("imm")
state = 2
continue
body.append(l)
continue
if state == 2:
p = loop_end_regexp[1].match(l_str)
if p is not None:
state = 3
continue
body.append(l)
continue
if state == 3:
post.append(l)
continue
if state < 3:
raise FatalParsingException(f"Couldn't identify loop {lbl}")
return pre, body, post, lbl, (reg0, reg1, imm)

class FatalParsingException(Exception):
"""A fatal error happened during instruction parsing"""

class UnknownInstruction(Exception):
"""The parent instruction class for the given object could not be found"""

class UnknownRegister(Exception):
"""The register could not be found"""

class Instruction:

Expand Down
Loading

0 comments on commit 90a78c0

Please sign in to comment.