Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More general structure for loop parsing #91

Merged
merged 8 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,37 @@ def core(self, slothy):
slothy.config.inputs_are_outputs = True
slothy.optimize_loop("start")

class LoopLe(Example):
    """Example driving the optimizer over the `loop_le` input.

    The input file name is `loop_le`, optionally extended by `_<var>`.
    The example name additionally carries the label of the chosen target.
    """

    def __init__(self, var="", arch=Arch_Armv81M, target=Target_CortexM55r1):
        base = "loop_le"
        # An empty variant contributes no suffix at all.
        variant = f"_{var}" if var != "" else ""

        infile = base + variant
        name = base + variant + f"_{target_label_dict[target]}"

        super().__init__(infile, name, rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Enable `variable_size` and optimize the loop labelled `start`."""
        slothy.config.variable_size = True
        slothy.optimize_loop("start")

class AArch64LoopSubs(Example):
    """Example driving the optimizer over the `aarch64_loop_subs` input.

    The input file name is `aarch64_loop_subs`, optionally extended by
    `_<var>`. The example name additionally carries the label of the
    chosen target.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
        base = "aarch64_loop_subs"
        # An empty variant contributes no suffix at all.
        variant = f"_{var}" if var != "" else ""

        infile = base + variant
        name = base + variant + f"_{target_label_dict[target]}"

        super().__init__(infile, name, rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Enable `variable_size` and optimize the loop labelled `start`."""
        slothy.config.variable_size = True
        slothy.optimize_loop("start")

class CRT(Example):
def __init__(self):
Expand Down Expand Up @@ -1394,6 +1425,10 @@ def main():
AArch64Example2(),
AArch64Example2(target=Target_CortexA72),

# Loop examples
AArch64LoopSubs(),
LoopLe(),

CRT(),

ntt_n256_l6_s32("bar"),
Expand Down
9 changes: 9 additions & 0 deletions examples/naive/aarch64/aarch64_loop_subs.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Minimal counted loop used as input for the AArch64LoopSubs example.
// The counter `count` (x2) is initialised to 16; the loop body is a single
// nop, and the loop is closed by a flag-setting subtract followed by cbnz.
count .req x2

mov count, #16
start:

nop

subs count, count, #1
cbnz count, start
4 changes: 4 additions & 0 deletions examples/naive/loop_le.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Minimal loop used as input for the LoopLe example: lr is initialised
// to 16, the body is a single nop, and the loop is closed by `le lr, start`.
mov lr, #16
start:
nop
le lr, start
20 changes: 12 additions & 8 deletions slothy/core/slothy.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,11 @@ def fusion_loop(self, loop_lbl):
"""Run fusion callbacks on loop body"""
logger = self.logger.getChild(f"ssa_loop_{loop_lbl}")

pre , body, post, _, other_data = \
pre , body, post, _, other_data, loop = \
self.arch.Loop.extract(self.source, loop_lbl)
(loop_cnt, _, _) = other_data
loop_cnt = other_data['cnt']
indentation = AsmHelper.find_indentation(body)

loop = self.arch.Loop(lbl_start=loop_lbl)
body_ssa = SourceLine.read_multiline(loop.start(loop_cnt)) + \
SourceLine.apply_indentation(self._fusion_core(pre, body, logger), indentation) + \
SourceLine.read_multiline(loop.end(other_data))
Expand All @@ -394,13 +393,15 @@ def fusion_loop(self, loop_lbl):
assert SourceLine.is_source(self.source)

def optimize_loop(self, loop_lbl, postamble_label=None):
"""Optimize the loop starting at a given label"""
"""Optimize the loop starting at a given label
The postamble_label marks the end of the loop kernel.
"""

logger = self.logger.getChild(loop_lbl)

early, body, late, _, other_data = \
early, body, late, _, other_data, loop = \
self.arch.Loop.extract(self.source, loop_lbl)
(loop_cnt, _, _) = other_data
loop_cnt = other_data['cnt']

# Check if the body has a dominant indentation
indentation = AsmHelper.find_indentation(body)
Expand Down Expand Up @@ -464,7 +465,6 @@ def loop_lbl_iter(i):
for i in range(1, num_exceptional):
optimized_code += indented(self.arch.Branch.if_equal(loop_cnt, i, loop_lbl_iter(i)))

loop = self.arch.Loop(lbl_start=loop_lbl)
hanno-becker marked this conversation as resolved.
Show resolved Hide resolved
optimized_code += indented(preamble_code)

if self.config.sw_pipelining.unknown_iteration_count:
Expand All @@ -479,7 +479,11 @@ def loop_lbl_iter(i):
indentation=self.config.indentation,
fixup=num_exceptional,
unroll=self.config.sw_pipelining.unroll,
jump_if_empty=jump_if_empty))
jump_if_empty=jump_if_empty,
preamble_code=preamble_code,
body_code=kernel_code,
postamble_code=postamble_code,
register_aliases=c.register_aliases))
optimized_code += indented(kernel_code)
optimized_code += SourceLine.read_multiline(loop.end(other_data,
indentation=self.config.indentation))
Expand Down
106 changes: 106 additions & 0 deletions slothy/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import re
import subprocess
import logging
from abc import ABC, abstractmethod

from slothy.targets.common import *

class SourceLine:
"""Representation of a single line of source code"""
Expand Down Expand Up @@ -1089,3 +1092,106 @@ def forward_to_file(self, log_label, filename, lvl=logging.DEBUG):
h.setLevel(lvl)
l.addHandler(h)
self.forward(l)


class Loop(ABC):
    """Base class for parsing and emitting simple loops.

    A concrete loop type subclasses this class and provides two instance
    attributes that drive `_extract`:

    * `lbl_regex`: pattern with the named groups `label` and `remainder`,
      matching a line that carries a label.
    * `end_regex`: sequence of patterns that match, in order, the
      instructions terminating the loop. It shall contain a group `cnt`
      naming the counter variable. All named groups are collected into
      `additional_data`.
    """

    def __init__(self, lbl_start="1", lbl_end="2", loop_init="lr"):
        self.lbl_start = lbl_start
        self.lbl_end = lbl_end
        self.loop_init = loop_init
        # Named capture groups gathered from the loop-end patterns.
        self.additional_data = {}

    @abstractmethod
    def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None,
              preamble_code=None, body_code=None, postamble_code=None,
              register_aliases=None):
        """Emit starting instruction(s) and jump label for loop"""

    @abstractmethod
    def end(self, other, indentation=0):
        """Emit compare-and-branch at the end of the loop"""

    def _extract(self, source, lbl):
        """Locate a loop with start label `lbl` in `source`.

        Args:
            source: list of source line objects. Each must offer `.text`
                and `.copy().set_text(...)`.
            lbl: label marking the start of the loop.

        Returns:
            The tuple `(pre, body, post, lbl, additional_data)`: the lines
            before the loop, inside it and after it, plus the named groups
            captured by the loop-end patterns.

        Raises:
            FatalParsingException: if no complete loop with this label and
                this loop type's end sequence is found.
        """
        assert isinstance(source, list)

        pre = []
        body = []
        post = []
        # Lines that matched a prefix of the loop-end sequence so far.
        loop_end_candidates = []
        loop_lbl_regexp = re.compile(self.lbl_regex)
        loop_end_regexp = [re.compile(txt) for txt in self.end_regex]

        lines = iter(source)
        line = None
        keep = False
        state = 0  # 0: haven't found loop yet, 1: extracting loop, 2: after loop
        loop_end_ctr = 0
        while True:
            if not keep:
                line = next(lines, None)
            keep = False
            if line is None:
                break
            l_str = line.text
            assert isinstance(line, str) is False
            if state == 0:
                p = loop_lbl_regexp.match(l_str)
                if p is not None and p.group("label") == lbl:
                    # Re-process whatever follows the label on the same
                    # line as the first line of the loop body.
                    line = line.copy().set_text(p.group("remainder"))
                    keep = True
                    state = 1
                else:
                    pre.append(line)
                continue
            if state == 1:
                p = loop_end_regexp[loop_end_ctr].match(l_str)
                if p is None and loop_end_ctr > 0 and l_str != "":
                    # The sequence of loop end candidates was interrupted:
                    # they were a false positive and belong to the body.
                    body += loop_end_candidates
                    self.additional_data = {}
                    loop_end_ctr = 0
                    loop_end_candidates = []
                    # The interrupting line may itself begin the real
                    # loop end, so test it against the first pattern.
                    p = loop_end_regexp[0].match(l_str)
                if p is not None:
                    # We may have found part of the loop end; collect all
                    # named groups.
                    self.additional_data = self.additional_data | p.groupdict()
                    loop_end_ctr += 1
                    loop_end_candidates.append(line)
                    if loop_end_ctr == len(loop_end_regexp):
                        state = 2
                    continue
                body.append(line)
                continue
            if state == 2:
                post.append(line)
                continue
        if state < 2:
            raise FatalParsingException(f"Couldn't identify loop {lbl}")
        return pre, body, post, lbl, self.additional_data

    @staticmethod
    def extract(source, lbl):
        """Try each direct subclass of `Loop` in turn to parse loop `lbl`.

        Returns:
            The tuple from `_extract` of the first loop type that parses
            successfully, extended by that loop type's instance as the
            last element.

        Raises:
            FatalParsingException: if no loop type recognises the loop.
        """
        for loop_type in Loop.__subclasses__():
            try:
                loop = loop_type(lbl)
                # (loop,) is a one-element tuple appended to the tuple
                # returned by _extract.
                return loop._extract(source, lbl) + (loop,)
            except FatalParsingException:
                # Not this loop type; fall through to the next candidate.
                logging.debug("Parsing loop type '%s' failed", loop_type)

        raise FatalParsingException(f"Couldn't identify loop {lbl}")
121 changes: 29 additions & 92 deletions slothy/targets/aarch64/aarch64_neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class which generates instruction parsers and writers from instruction templates

from sympy import simplify

from slothy.targets.common import *
from slothy.helper import Loop

arch_name = "Arm_AArch64"
llvm_mca_arch = "aarch64"

Expand Down Expand Up @@ -169,119 +172,53 @@ def unconditional(lbl):
"""Emit unconditional branch"""
yield f"b {lbl}"

class Loop:
"""Helper functions for parsing and writing simple loops in AArch64

TODO: Generalize; current implementation too specific about shape of loop"""

def __init__(self, lbl_start="1", lbl_end="2", loop_init="lr"):
self.lbl_start = lbl_start
self.lbl_end = lbl_end
self.loop_init = loop_init
class SubsLoop(Loop):
"""
Loop ending in a flag setting subtraction and a branch.

def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None):
Example:
```
loop_lbl:
{code}
sub[s] <cnt>, <cnt>, #<imm>
(cbnz|bnz|bne) <cnt>, loop_lbl
```
where cnt is the loop counter in lr.
"""
def __init__(self, lbl="lbl", lbl_start="1", lbl_end="2", loop_init="lr") -> None:
hanno-becker marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(lbl_start=lbl_start, lbl_end=lbl_end, loop_init=loop_init)
# The group naming in the regex should be consistent; give same group
# names to the same registers
self.lbl_regex = r"^\s*(?P<label>\w+)\s*:(?P<remainder>.*)$"
self.end_regex = (r"^\s*sub[s]?\s+(?P<cnt>\w+),\s*(?P<reg1>\w+),\s*#(?P<imm>\d+)",
rf"^\s*(cbnz|bnz|bne)\s+(?P<cnt>\w+),\s*{lbl}")

def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None, preamble_code=None, body_code=None, postamble_code=None, register_aliases=None):
"""Emit starting instruction(s) and jump label for loop"""
indent = ' ' * indentation
if unroll > 1:
assert unroll in [1,2,4,8,16,32]
yield f"{indent}lsr {loop_cnt}, {loop_cnt}, #{int(math.log2(unroll))}"
if fixup != 0:
# In case the immediate is >1, we need to scale the fixup. This
# allows for loops that do not use an increment of 1
fixup *= self.additional_data['imm']
yield f"{indent}sub {loop_cnt}, {loop_cnt}, #{fixup}"
if jump_if_empty is not None:
yield f"cbz {loop_cnt}, {jump_if_empty}"
yield f"{self.lbl_start}:"

def end(self, other, indentation=0):
"""Emit compare-and-branch at the end of the loop"""
(reg0, reg1, imm) = other
indent = ' ' * indentation
lbl_start = self.lbl_start
if lbl_start.isdigit():
lbl_start += "b"

yield f"{indent}sub {reg0}, {reg1}, {imm}"
yield f"{indent}cbnz {reg0}, {lbl_start}"

@staticmethod
def extract(source, lbl):
"""Locate a loop with start label `lbl` in `source`.

We currently only support the following loop forms:
yield f"{indent}sub {other['cnt']}, {other['cnt']}, {other['imm']}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use the same format that the original loop had? With/without flag, and potentially using cbnz, bnz, bne?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pre-existing, so let's not block the PR because of it.

yield f"{indent}cbnz {other['cnt']}, {lbl_start}"

```
loop_lbl:
{code}
sub[s] <cnt>, <cnt>, #1
(cbnz|bnz|bne) <cnt>, loop_lbl
```

"""
assert isinstance(source, list)

pre = []
body = []
post = []
loop_lbl_regexp_txt = r"^\s*(?P<label>\w+)\s*:(?P<remainder>.*)$"
loop_lbl_regexp = re.compile(loop_lbl_regexp_txt)

# TODO: Allow other forms of looping

loop_end_regexp_txt = (r"^\s*sub[s]?\s+(?P<reg0>\w+),\s*(?P<reg1>\w+),\s*(?P<imm>#1)",
rf"^\s*(cbnz|bnz|bne)\s+(?P<reg0>\w+),\s*{lbl}")
loop_end_regexp = [re.compile(txt) for txt in loop_end_regexp_txt]
lines = iter(source)
l = None
keep = False
state = 0 # 0: haven't found loop yet, 1: extracting loop, 2: after loop
while True:
if not keep:
l = next(lines, None)
keep = False
if l is None:
break
l_str = l.text
assert isinstance(l, str) is False
if state == 0:
p = loop_lbl_regexp.match(l_str)
if p is not None and p.group("label") == lbl:
l = l.copy().set_text(p.group("remainder"))
keep = True
state = 1
else:
pre.append(l)
continue
if state == 1:
p = loop_end_regexp[0].match(l_str)
if p is not None:
reg0 = p.group("reg0")
reg1 = p.group("reg1")
imm = p.group("imm")
state = 2
continue
body.append(l)
continue
if state == 2:
p = loop_end_regexp[1].match(l_str)
if p is not None:
state = 3
continue
body.append(l)
continue
if state == 3:
post.append(l)
continue
if state < 3:
raise FatalParsingException(f"Couldn't identify loop {lbl}")
return pre, body, post, lbl, (reg0, reg1, imm)

class FatalParsingException(Exception):
"""A fatal error happened during instruction parsing"""

class UnknownInstruction(Exception):
"""The parent instruction class for the given object could not be found"""

class UnknownRegister(Exception):
"""The register could not be found"""

class Instruction:

Expand Down
Loading