From b1a4ab5fb0d8cdfa7cfd6d9d07c8f8e6004afd11 Mon Sep 17 00:00:00 2001 From: Amin Abdulrahman Date: Fri, 11 Oct 2024 16:01:21 +0200 Subject: [PATCH 1/8] More general structure for loop parsing --- slothy/core/slothy.py | 10 +-- slothy/targets/aarch64/aarch64_neon.py | 118 ++++++++++++++----------- slothy/targets/arm_v81m/arch_v81m.py | 108 +++++++++++++++------- 3 files changed, 146 insertions(+), 90 deletions(-) diff --git a/slothy/core/slothy.py b/slothy/core/slothy.py index 3fb67a41..bcff4108 100644 --- a/slothy/core/slothy.py +++ b/slothy/core/slothy.py @@ -367,12 +367,11 @@ def fusion_loop(self, loop_lbl): """Run fusion callbacks on loop body""" logger = self.logger.getChild(f"ssa_loop_{loop_lbl}") - pre , body, post, _, other_data = \ + pre , body, post, _, other_data, loop = \ self.arch.Loop.extract(self.source, loop_lbl) - (loop_cnt, _, _) = other_data + loop_cnt = other_data['cnt'] indentation = AsmHelper.find_indentation(body) - loop = self.arch.Loop(lbl_start=loop_lbl) body_ssa = SourceLine.read_multiline(loop.start(loop_cnt)) + \ SourceLine.apply_indentation(self._fusion_core(pre, body, logger), indentation) + \ SourceLine.read_multiline(loop.end(other_data)) @@ -398,9 +397,9 @@ def optimize_loop(self, loop_lbl, postamble_label=None): logger = self.logger.getChild(loop_lbl) - early, body, late, _, other_data = \ + early, body, late, _, other_data, loop = \ self.arch.Loop.extract(self.source, loop_lbl) - (loop_cnt, _, _) = other_data + loop_cnt = other_data['cnt'] # Check if the body has a dominant indentation indentation = AsmHelper.find_indentation(body) @@ -464,7 +463,6 @@ def loop_lbl_iter(i): for i in range(1, num_exceptional): optimized_code += indented(self.arch.Branch.if_equal(loop_cnt, i, loop_lbl_iter(i))) - loop = self.arch.Loop(lbl_start=loop_lbl) optimized_code += indented(preamble_code) if self.config.sw_pipelining.unknown_iteration_count: diff --git a/slothy/targets/aarch64/aarch64_neon.py b/slothy/targets/aarch64/aarch64_neon.py index b2e8f36f..fad73cc0 100644 --- a/slothy/targets/aarch64/aarch64_neon.py +++ b/slothy/targets/aarch64/aarch64_neon.py @@ -43,6 +43,7 @@ class which generates instruction parsers and writers from instruction templates import math from enum import Enum from functools import cache +from abc import ABC, abstractmethod from sympy import simplify @@ -169,70 +170,46 @@ def unconditional(lbl): """Emit unconditional branch""" yield f"b {lbl}" -class Loop: - """Helper functions for parsing and writing simple loops in AArch64 - - TODO: Generalize; current implementation too specific about shape of loop""" - +class Loop(ABC): def __init__(self, lbl_start="1", lbl_end="2", loop_init="lr"): self.lbl_start = lbl_start self.lbl_end = lbl_end self.loop_init = loop_init + @abstractmethod def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None): """Emit starting instruction(s) and jump label for loop""" - indent = ' ' * indentation - if unroll > 1: - assert unroll in [1,2,4,8,16,32] - yield f"{indent}lsr {loop_cnt}, {loop_cnt}, #{int(math.log2(unroll))}" - if fixup != 0: - yield f"{indent}sub {loop_cnt}, {loop_cnt}, #{fixup}" - if jump_if_empty is not None: - yield f"cbz {loop_cnt}, {jump_if_empty}" - yield f"{self.lbl_start}:" + # TODO: Use different type of fixup for cmp vs. subs loops + pass + @abstractmethod def end(self, other, indentation=0): """Emit compare-and-branch at the end of the loop""" - (reg0, reg1, imm) = other - indent = ' ' * indentation - lbl_start = self.lbl_start - if lbl_start.isdigit(): - lbl_start += "b" - - yield f"{indent}sub {reg0}, {reg1}, {imm}" - yield f"{indent}cbnz {reg0}, {lbl_start}" - - @staticmethod - def extract(source, lbl): - """Locate a loop with start label `lbl` in `source`. - - We currently only support the following loop forms: - - ``` - loop_lbl: - {code} - sub[s] , , #1 - (cbnz|bnz|bne) , loop_lbl - ``` + pass + + def _extract(self, source, lbl): + """Locate a loop with start label `lbl` in `source`.``` """ assert isinstance(source, list) + + additional_data = None pre = [] body = [] post = [] - loop_lbl_regexp_txt = r"^\s*(?P