diff --git a/tests/filecheck/transforms/convert_riscv_to_llvm.mlir b/tests/filecheck/transforms/convert_riscv_to_llvm.mlir new file mode 100644 index 0000000000..401a0487ee --- /dev/null +++ b/tests/filecheck/transforms/convert_riscv_to_llvm.mlir @@ -0,0 +1,116 @@ +// RUN: xdsl-opt %s -p convert-riscv-to-llvm | filecheck %s +// RUN: xdsl-opt %s -p convert-riscv-to-llvm,reconcile-unrealized-casts,dce | filecheck %s --check-prefix COMPACT + + +%reg = riscv.li 0 : !riscv.reg +%a0 = riscv.li 0 : !riscv.reg +%x0 = riscv.get_register : !riscv.reg + +// CHECK: builtin.module { +// CHECK-NEXT: %reg = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// CHECK-NEXT: %reg_1 = builtin.unrealized_conversion_cast %reg : i32 to !riscv.reg +// CHECK-NEXT: %a0 = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// CHECK-NEXT: %a0_1 = builtin.unrealized_conversion_cast %a0 : i32 to !riscv.reg +// CHECK-NEXT: %x0 = riscv.get_register : !riscv.reg + +// standard risc-v instructions + +%li = riscv.li 0 : !riscv.reg +// CHECK-NEXT: %li = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// CHECK-NEXT: %li_1 = builtin.unrealized_conversion_cast %li : i32 to !riscv.reg + +%sub = riscv.sub %reg, %reg : (!riscv.reg, !riscv.reg) -> !riscv.reg +// CHECK-NEXT: %sub = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %sub_1 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %sub_2 = "llvm.inline_asm"(%sub, %sub_1) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// CHECK-NEXT: %sub_3 = builtin.unrealized_conversion_cast %sub_2 : i32 to !riscv.reg + +%div = riscv.div %reg, %reg : (!riscv.reg, !riscv.reg) -> !riscv.reg +// CHECK-NEXT: %div = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %div_1 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %div_2 = "llvm.inline_asm"(%div, %div_1) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// CHECK-NEXT: %div_3 = builtin.unrealized_conversion_cast %div_2 : i32 to !riscv.reg + +// named riscv registers: + +%li_named = riscv.li 0 : !riscv.reg +// CHECK-NEXT: %li_named = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// CHECK-NEXT: %li_named_1 = builtin.unrealized_conversion_cast %li_named : i32 to !riscv.reg + +%sub_named = riscv.sub %a0, %a0 : (!riscv.reg, !riscv.reg) -> !riscv.reg +// CHECK-NEXT: %sub_named = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg to i32 +// CHECK-NEXT: %sub_named_1 = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg to i32 +// CHECK-NEXT: %sub_named_2 = "llvm.inline_asm"(%sub_named, %sub_named_1) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// CHECK-NEXT: %sub_named_3 = builtin.unrealized_conversion_cast %sub_named_2 : i32 to !riscv.reg + +%div_named = riscv.div %a0, %a0 : (!riscv.reg, !riscv.reg) -> !riscv.reg +// CHECK-NEXT: %div_named = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg to i32 +// CHECK-NEXT: %div_named_1 = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg to i32 +// CHECK-NEXT: %div_named_2 = "llvm.inline_asm"(%div_named, %div_named_1) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// CHECK-NEXT: %div_named_3 = builtin.unrealized_conversion_cast %div_named_2 : i32 to !riscv.reg + + +// csr instructions + +%csrss = riscv.csrrs %x0, 3860, "r" : (!riscv.reg) -> !riscv.reg +// CHECK-NEXT: %csrss = "llvm.inline_asm"() <{"asm_string" = "csrrs $0, 3860, x0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// CHECK-NEXT: %csrss_1 = builtin.unrealized_conversion_cast %csrss : i32 to !riscv.reg + +%csrrs = riscv.csrrs %x0, 1986 : (!riscv.reg) -> !riscv.reg +// CHECK-NEXT: %csrrs = riscv.get_register : !riscv.reg +// CHECK-NEXT: "llvm.inline_asm"() <{"asm_string" = "csrrs x0, 1986, x0", "constraints" = "", "asm_dialect" = 0 : i64}> : () -> () + +%csrrci = riscv.csrrci 1984, 1 : () -> !riscv.reg +// CHECK-NEXT: %csrrci = "llvm.inline_asm"() <{"asm_string" = "csrrci $0, 1984, 1", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// CHECK-NEXT: %csrrci_1 = builtin.unrealized_conversion_cast %csrrci : i32 to !riscv.reg + + +// custom snitch instructions + +riscv_snitch.dmsrc %reg, %reg : (!riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: "llvm.inline_asm"(%0, %1) <{"asm_string" = ".insn r 0x2b, 0, 0, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> () + +riscv_snitch.dmdst %reg, %reg : (!riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: %2 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %3 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: "llvm.inline_asm"(%2, %3) <{"asm_string" = ".insn r 0x2b, 0, 1, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> () + +riscv_snitch.dmstr %reg, %reg : (!riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: %4 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %5 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: "llvm.inline_asm"(%4, %5) <{"asm_string" = ".insn r 0x2b, 0, 6, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> () + +riscv_snitch.dmrep %reg : (!riscv.reg) -> () +// CHECK-NEXT: %6 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: "llvm.inline_asm"(%6) <{"asm_string" = ".insn r 0x2b, 0, 7, x0, $0, x0", "constraints" = "rI", "asm_dialect" = 0 : i64}> : (i32) -> () + +%dmcpyi = riscv_snitch.dmcpyi %reg, 2 : (!riscv.reg) -> !riscv.reg +// CHECK-NEXT: %dmcpyi = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32 +// CHECK-NEXT: %dmcpyi_1 = "llvm.inline_asm"(%dmcpyi) <{"asm_string" = ".insn r 0x2b, 0, 2, $0, $1, 2", "constraints" = "=r,rI", "asm_dialect" = 0 : i64}> : (i32) -> i32 +// CHECK-NEXT: %dmcpyi_2 = builtin.unrealized_conversion_cast %dmcpyi_1 : i32 to !riscv.reg + + +// ------------------------------------------------------- // +// compact representation after reconciling casts and DCE: // +// ------------------------------------------------------- // + +// COMPACT: builtin.module { +// COMPACT-NEXT: %reg = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// COMPACT-NEXT: %a0 = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// COMPACT-NEXT: %li = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// COMPACT-NEXT: %sub = "llvm.inline_asm"(%reg, %reg) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// COMPACT-NEXT: %div = "llvm.inline_asm"(%reg, %reg) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// COMPACT-NEXT: %li_named = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// COMPACT-NEXT: %sub_named = "llvm.inline_asm"(%a0, %a0) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// COMPACT-NEXT: %div_named = "llvm.inline_asm"(%a0, %a0) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32 +// COMPACT-NEXT: %csrss = "llvm.inline_asm"() <{"asm_string" = "csrrs $0, 3860, x0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// COMPACT-NEXT: "llvm.inline_asm"() <{"asm_string" = "csrrs x0, 1986, x0", "constraints" = "", "asm_dialect" = 0 : i64}> : () -> () +// COMPACT-NEXT: %csrrci = "llvm.inline_asm"() <{"asm_string" = "csrrci $0, 1984, 1", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32 +// COMPACT-NEXT: "llvm.inline_asm"(%reg, %reg) <{"asm_string" = ".insn r 0x2b, 0, 0, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> () +// COMPACT-NEXT: "llvm.inline_asm"(%reg, %reg) <{"asm_string" = ".insn r 0x2b, 0, 1, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> () +// COMPACT-NEXT: "llvm.inline_asm"(%reg, %reg) <{"asm_string" = ".insn r 0x2b, 0, 6, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> () +// COMPACT-NEXT: "llvm.inline_asm"(%reg) <{"asm_string" = ".insn r 0x2b, 0, 7, x0, $0, x0", "constraints" = "rI", "asm_dialect" = 0 : i64}> : (i32) -> () +// COMPACT-NEXT: %dmcpyi = "llvm.inline_asm"(%reg) <{"asm_string" = ".insn r 0x2b, 0, 2, $0, $1, 2", "constraints" = "=r,rI", "asm_dialect" = 0 : i64}> : (i32) -> i32 +// COMPACT-NEXT: } diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index b19fae4e80..c786193198 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -133,6 +133,11 @@ def get_convert_riscv_scf_to_riscv_cf(): return convert_riscv_scf_to_riscv_cf.ConvertRiscvScfToRiscvCfPass + def get_convert_riscv_to_llvm(): + from xdsl.transforms import convert_riscv_to_llvm + + return convert_riscv_to_llvm.ConvertRiscvToLLVMPass + def get_convert_scf_to_cf(): from xdsl.transforms import convert_scf_to_cf @@ -491,6 +496,7 @@ def get_varith_fuse_repeated_operands(): "convert-qssa-to-qref": get_convert_qssa_to_qref, "convert-riscv-scf-for-to-frep": get_convert_riscv_scf_for_to_frep, "convert-riscv-scf-to-riscv-cf": get_convert_riscv_scf_to_riscv_cf, + "convert-riscv-to-llvm": get_convert_riscv_to_llvm, "convert-scf-to-cf": get_convert_scf_to_cf, "convert-scf-to-openmp": get_convert_scf_to_openmp, "convert-scf-to-riscv-scf": get_convert_scf_to_riscv_scf, diff --git a/xdsl/transforms/convert_riscv_to_llvm.py b/xdsl/transforms/convert_riscv_to_llvm.py new file mode 100644 index 0000000000..d8f4df27cb --- /dev/null +++ b/xdsl/transforms/convert_riscv_to_llvm.py @@ -0,0 +1,155 @@ +from dataclasses import dataclass + +from xdsl.context import MLContext +from xdsl.dialects import builtin, riscv +from xdsl.dialects.builtin import IntAttr, IntegerAttr, UnrealizedConversionCastOp +from xdsl.dialects.llvm import InlineAsmOp +from xdsl.dialects.riscv import IntRegisterType, RISCVInstruction +from xdsl.ir import Operation, OpResult, SSAValue +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.traits import HasInsnRepresentation +from xdsl.utils.exceptions import DiagnosticException + + +@dataclass(frozen=True) +class RiscvToLLVMPattern(RewritePattern): + xlen: int + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: RISCVInstruction, rewriter: PatternRewriter): + ops_to_insert: list[Operation] = [] + + # inputs for the llvm inline asm op + assembly_args_str: list[str] = [] + constraints: list[str] = [] + inputs: list[SSAValue | OpResult] = [] + # number of results produced from the inline assembly op + # all results are considered to be XLEN long integers + num_results: int = 0 + # keep track which results are taken from the inline assembly op and which one are "other": + # int -> index into the inline asm op result list + # ssa val -> use this ssa value instead (e.g. when the op "returns" the zero register) + result_map: list[int | SSAValue] = [] + + # populate assembly_args_str and constraints + for arg in op.assembly_line_args(): + # ssa value used as an output operand + match arg: + case OpResult() if arg.owner is op and isinstance( + arg.type, IntRegisterType + ): + # if we are storing to zero, we can't produce a result, so replace result by + # a get_register for the zero registers. + if arg.type.is_allocated and arg.type.index == IntAttr(0): + assembly_args_str.append("x0") + ops_to_insert.append(zero := riscv.GetRegisterOp(arg.type)) + # map final result to an existing SSA value + result_map.append(zero.res) + continue + # all other registers are treated as if they were unallocated + # meaning we cast them to i32 and pass values to the op + assembly_args_str.append(f"${len(inputs) + num_results}") + constraints.append("=r") + # map final result to a result of the inline asm op + result_map.append(num_results) + num_results += 1 + + case SSAValue() if isinstance(arg.type, IntRegisterType): + # if the input is allocated to a zero register, use that register + # other allocated registers are treaded as if they were unallocated + if arg.type.is_allocated and arg.type.index == IntAttr(0): + assembly_args_str.append("x0") + # otherwise we need to get the value from the SSA value + else: + conversion_op = UnrealizedConversionCastOp.get( + [arg], [builtin.i32] + ) + ops_to_insert.append(conversion_op) + inputs.append(conversion_op.outputs[0]) + constraints.append("rI") + assembly_args_str.append(f"${len(inputs) + num_results - 1}") + + case IntegerAttr(): + # constant value used as an immediate + assembly_args_str.append(str(arg.value.data)) + + case _: + raise DiagnosticException( + "unsupported argument for conversion to an llvm inline assembly instruction" + ) + + # construct asm_string + instruction_name = op.assembly_instruction_name() + + # check if the operation has a custom insn string (for compatibility reasons) + custom_insn = op.get_trait(HasInsnRepresentation) + if custom_insn is not None: + # generate custom insn inline assembly instruction + insn_str = custom_insn.get_insn(op) + asm_string = insn_str.format(*assembly_args_str) + else: + # generate generic riscv inline assembly instruction + asm_string = instruction_name + " " + ", ".join(assembly_args_str) + + # construct constraints_string + constraints_string = ",".join(constraints) + + # construct llvm inline asm op + register_width_int = builtin.IntegerType(self.xlen) + ops_to_insert.append( + new_op := InlineAsmOp( + asm_string, + constraints_string, + inputs, + [register_width_int] * num_results, + ) + ) + op_results = new_op.results + + # cast output back to original type if necessary + if num_results: + ops_to_insert.append( + output_op := UnrealizedConversionCastOp.get( + new_op.results, [r.type for r in op.results] + ) + ) + op_results = output_op.results + + # map results back using result_map + rewriter.replace_matched_op( + ops_to_insert, + [op_results[i] if isinstance(i, int) else i for i in result_map], + ) + + +class ConvertRiscvToLLVMPass(ModulePass): + """ + Convert RISC-V instructions to LLVM inline assembly. This allows for the use + of an LLVM backend instead of direct RISC-V assembly generation. Additionally, + custom ops are implemented using .insn directives, to avoid the need for a + custom LLVM backend. + + Only integer register types are supported. Specify register width through the + xlen pass argument. + + Due to the nature of inline assembly operations, this behaviour is very flaky + for code that has been register allocated, and will most likely break for all + non-trivial register allocated code. + + This pass handles register allocated operations by discarding allocated registers. + This breaks as soon as the riscv dialect code has non-SSA def-use chains (e.g. + through get_register ops). + """ + + name = "convert-riscv-to-llvm" + + xlen: int = 32 + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + PatternRewriteWalker(RiscvToLLVMPattern(self.xlen)).rewrite_module(op)