Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: (convert-riscv-to-llvm) #2468

Merged
merged 21 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions tests/filecheck/transforms/convert_riscv_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// RUN: xdsl-opt %s -p convert-riscv-to-llvm | filecheck %s
// RUN: xdsl-opt %s -p convert-riscv-to-llvm,reconcile-unrealized-casts,dce | filecheck %s --check-prefix COMPACT


%reg = riscv.li 0 : !riscv.reg
%a0 = riscv.li 0 : !riscv.reg<a0>
%x0 = riscv.get_register : !riscv.reg<zero>

// CHECK: builtin.module {
// CHECK-NEXT: %reg = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// CHECK-NEXT: %reg_1 = builtin.unrealized_conversion_cast %reg : i32 to !riscv.reg
// CHECK-NEXT: %a0 = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// CHECK-NEXT: %a0_1 = builtin.unrealized_conversion_cast %a0 : i32 to !riscv.reg<a0>
// CHECK-NEXT: %x0 = riscv.get_register : !riscv.reg<zero>

// standard risc-v instructions

%li = riscv.li 0 : !riscv.reg
// CHECK-NEXT: %li = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// CHECK-NEXT: %li_1 = builtin.unrealized_conversion_cast %li : i32 to !riscv.reg

%sub = riscv.sub %reg, %reg : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %sub = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %sub_1 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %sub_2 = "llvm.inline_asm"(%sub, %sub_1) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// CHECK-NEXT: %sub_3 = builtin.unrealized_conversion_cast %sub_2 : i32 to !riscv.reg

%div = riscv.div %reg, %reg : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %div = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %div_1 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %div_2 = "llvm.inline_asm"(%div, %div_1) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// CHECK-NEXT: %div_3 = builtin.unrealized_conversion_cast %div_2 : i32 to !riscv.reg

// named riscv registers:

%li_named = riscv.li 0 : !riscv.reg<a0>
// CHECK-NEXT: %li_named = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// CHECK-NEXT: %li_named_1 = builtin.unrealized_conversion_cast %li_named : i32 to !riscv.reg<a0>

%sub_named = riscv.sub %a0, %a0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> !riscv.reg
// CHECK-NEXT: %sub_named = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg<a0> to i32
// CHECK-NEXT: %sub_named_1 = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg<a0> to i32
// CHECK-NEXT: %sub_named_2 = "llvm.inline_asm"(%sub_named, %sub_named_1) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// CHECK-NEXT: %sub_named_3 = builtin.unrealized_conversion_cast %sub_named_2 : i32 to !riscv.reg

%div_named = riscv.div %a0, %a0 : (!riscv.reg<a0>, !riscv.reg<a0>) -> !riscv.reg
// CHECK-NEXT: %div_named = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg<a0> to i32
// CHECK-NEXT: %div_named_1 = builtin.unrealized_conversion_cast %a0_1 : !riscv.reg<a0> to i32
// CHECK-NEXT: %div_named_2 = "llvm.inline_asm"(%div_named, %div_named_1) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// CHECK-NEXT: %div_named_3 = builtin.unrealized_conversion_cast %div_named_2 : i32 to !riscv.reg


// csr instructions

%csrss = riscv.csrrs %x0, 3860, "r" : (!riscv.reg<zero>) -> !riscv.reg
// CHECK-NEXT: %csrss = "llvm.inline_asm"() <{"asm_string" = "csrrs $0, 3860, x0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// CHECK-NEXT: %csrss_1 = builtin.unrealized_conversion_cast %csrss : i32 to !riscv.reg

%csrrs = riscv.csrrs %x0, 1986 : (!riscv.reg<zero>) -> !riscv.reg<zero>
// CHECK-NEXT: %csrrs = riscv.get_register : !riscv.reg<zero>
// CHECK-NEXT: "llvm.inline_asm"() <{"asm_string" = "csrrs x0, 1986, x0", "constraints" = "", "asm_dialect" = 0 : i64}> : () -> ()

%csrrci = riscv.csrrci 1984, 1 : () -> !riscv.reg
// CHECK-NEXT: %csrrci = "llvm.inline_asm"() <{"asm_string" = "csrrci $0, 1984, 1", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// CHECK-NEXT: %csrrci_1 = builtin.unrealized_conversion_cast %csrrci : i32 to !riscv.reg


// custom snitch instructions

riscv_snitch.dmsrc %reg, %reg : (!riscv.reg, !riscv.reg) -> ()
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: "llvm.inline_asm"(%0, %1) <{"asm_string" = ".insn r 0x2b, 0, 0, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> ()

riscv_snitch.dmdst %reg, %reg : (!riscv.reg, !riscv.reg) -> ()
// CHECK-NEXT: %2 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %3 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: "llvm.inline_asm"(%2, %3) <{"asm_string" = ".insn r 0x2b, 0, 1, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> ()

riscv_snitch.dmstr %reg, %reg : (!riscv.reg, !riscv.reg) -> ()
// CHECK-NEXT: %4 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %5 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: "llvm.inline_asm"(%4, %5) <{"asm_string" = ".insn r 0x2b, 0, 6, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> ()

riscv_snitch.dmrep %reg : (!riscv.reg) -> ()
// CHECK-NEXT: %6 = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: "llvm.inline_asm"(%6) <{"asm_string" = ".insn r 0x2b, 0, 7, x0, $0, x0", "constraints" = "rI", "asm_dialect" = 0 : i64}> : (i32) -> ()

%dmcpyi = riscv_snitch.dmcpyi %reg, 2 : (!riscv.reg) -> !riscv.reg
// CHECK-NEXT: %dmcpyi = builtin.unrealized_conversion_cast %reg_1 : !riscv.reg to i32
// CHECK-NEXT: %dmcpyi_1 = "llvm.inline_asm"(%dmcpyi) <{"asm_string" = ".insn r 0x2b, 0, 2, $0, $1, 2", "constraints" = "=r,rI", "asm_dialect" = 0 : i64}> : (i32) -> i32
// CHECK-NEXT: %dmcpyi_2 = builtin.unrealized_conversion_cast %dmcpyi_1 : i32 to !riscv.reg


// ------------------------------------------------------- //
// compact representation after reconciling casts and DCE: //
// ------------------------------------------------------- //

// COMPACT: builtin.module {
// COMPACT-NEXT: %reg = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// COMPACT-NEXT: %a0 = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// COMPACT-NEXT: %li = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// COMPACT-NEXT: %sub = "llvm.inline_asm"(%reg, %reg) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// COMPACT-NEXT: %div = "llvm.inline_asm"(%reg, %reg) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// COMPACT-NEXT: %li_named = "llvm.inline_asm"() <{"asm_string" = "li $0, 0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// COMPACT-NEXT: %sub_named = "llvm.inline_asm"(%a0, %a0) <{"asm_string" = "sub $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// COMPACT-NEXT: %div_named = "llvm.inline_asm"(%a0, %a0) <{"asm_string" = "div $0, $1, $2", "constraints" = "=r,rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> i32
// COMPACT-NEXT: %csrss = "llvm.inline_asm"() <{"asm_string" = "csrrs $0, 3860, x0", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// COMPACT-NEXT: "llvm.inline_asm"() <{"asm_string" = "csrrs x0, 1986, x0", "constraints" = "", "asm_dialect" = 0 : i64}> : () -> ()
// COMPACT-NEXT: %csrrci = "llvm.inline_asm"() <{"asm_string" = "csrrci $0, 1984, 1", "constraints" = "=r", "asm_dialect" = 0 : i64}> : () -> i32
// COMPACT-NEXT: "llvm.inline_asm"(%reg, %reg) <{"asm_string" = ".insn r 0x2b, 0, 0, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> ()
// COMPACT-NEXT: "llvm.inline_asm"(%reg, %reg) <{"asm_string" = ".insn r 0x2b, 0, 1, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> ()
// COMPACT-NEXT: "llvm.inline_asm"(%reg, %reg) <{"asm_string" = ".insn r 0x2b, 0, 6, x0, $0, $1", "constraints" = "rI,rI", "asm_dialect" = 0 : i64}> : (i32, i32) -> ()
// COMPACT-NEXT: "llvm.inline_asm"(%reg) <{"asm_string" = ".insn r 0x2b, 0, 7, x0, $0, x0", "constraints" = "rI", "asm_dialect" = 0 : i64}> : (i32) -> ()
// COMPACT-NEXT: %dmcpyi = "llvm.inline_asm"(%reg) <{"asm_string" = ".insn r 0x2b, 0, 2, $0, $1, 2", "constraints" = "=r,rI", "asm_dialect" = 0 : i64}> : (i32) -> i32
// COMPACT-NEXT: }
6 changes: 6 additions & 0 deletions xdsl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ def get_convert_riscv_scf_to_riscv_cf():

return convert_riscv_scf_to_riscv_cf.ConvertRiscvScfToRiscvCfPass

def get_convert_riscv_to_llvm():
from xdsl.transforms import convert_riscv_to_llvm

return convert_riscv_to_llvm.ConvertRiscvToLLVMPass

def get_convert_scf_to_cf():
from xdsl.transforms import convert_scf_to_cf

Expand Down Expand Up @@ -491,6 +496,7 @@ def get_varith_fuse_repeated_operands():
"convert-qssa-to-qref": get_convert_qssa_to_qref,
"convert-riscv-scf-for-to-frep": get_convert_riscv_scf_for_to_frep,
"convert-riscv-scf-to-riscv-cf": get_convert_riscv_scf_to_riscv_cf,
"convert-riscv-to-llvm": get_convert_riscv_to_llvm,
"convert-scf-to-cf": get_convert_scf_to_cf,
"convert-scf-to-openmp": get_convert_scf_to_openmp,
"convert-scf-to-riscv-scf": get_convert_scf_to_riscv_scf,
Expand Down
155 changes: 155 additions & 0 deletions xdsl/transforms/convert_riscv_to_llvm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import builtin, riscv
from xdsl.dialects.builtin import IntAttr, IntegerAttr, UnrealizedConversionCastOp
from xdsl.dialects.llvm import InlineAsmOp
from xdsl.dialects.riscv import IntRegisterType, RISCVInstruction
from xdsl.ir import Operation, OpResult, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.traits import HasInsnRepresentation
from xdsl.utils.exceptions import DiagnosticException


@dataclass(frozen=True)
class RiscvToLLVMPattern(RewritePattern):
xlen: int

@op_type_rewrite_pattern
def match_and_rewrite(self, op: RISCVInstruction, rewriter: PatternRewriter):
ops_to_insert: list[Operation] = []

# inputs for the llvm inline asm op
assembly_args_str: list[str] = []
constraints: list[str] = []
inputs: list[SSAValue | OpResult] = []
# number of results produced from the inline assembly op
# all results are considered to be XLEN long integers
num_results: int = 0
# keep track which results are taken from the inline assembly op and which one are "other":
# int -> index into the inline asm op result list
# ssa val -> use this ssa value instead (e.g. when the op "returns" the zero register)
result_map: list[int | SSAValue] = []

# populate assembly_args_str and constraints
for arg in op.assembly_line_args():
# ssa value used as an output operand
match arg:
case OpResult() if arg.owner is op and isinstance(
arg.type, IntRegisterType
):
# if we are storing to zero, we can't produce a result, so replace result by
# a get_register for the zero registers.
if arg.type.is_allocated and arg.type.index == IntAttr(0):
assembly_args_str.append("x0")
ops_to_insert.append(zero := riscv.GetRegisterOp(arg.type))
# map final result to an existing SSA value
result_map.append(zero.res)
continue
# all other registers are treated as if they were unallocated
# meaning we cast them to i32 and pass values to the op
assembly_args_str.append(f"${len(inputs) + num_results}")
constraints.append("=r")
# map final result to a result of the inline asm op
result_map.append(num_results)
num_results += 1

case SSAValue() if isinstance(arg.type, IntRegisterType):
# if the input is allocated to a zero register, use that register
# other allocated registers are treaded as if they were unallocated
if arg.type.is_allocated and arg.type.index == IntAttr(0):
assembly_args_str.append("x0")
# otherwise we need to get the value from the SSA value
else:
conversion_op = UnrealizedConversionCastOp.get(
[arg], [builtin.i32]
)
ops_to_insert.append(conversion_op)
inputs.append(conversion_op.outputs[0])
constraints.append("rI")
assembly_args_str.append(f"${len(inputs) + num_results - 1}")

case IntegerAttr():
# constant value used as an immediate
assembly_args_str.append(str(arg.value.data))

case _:
raise DiagnosticException(
"unsupported argument for conversion to an llvm inline assembly instruction"
)

# construct asm_string
instruction_name = op.assembly_instruction_name()

# check if the operation has a custom insn string (for compatibility reasons)
custom_insn = op.get_trait(HasInsnRepresentation)
if custom_insn is not None:
# generate custom insn inline assembly instruction
insn_str = custom_insn.get_insn(op)
asm_string = insn_str.format(*assembly_args_str)
else:
# generate generic riscv inline assembly instruction
asm_string = instruction_name + " " + ", ".join(assembly_args_str)

# construct constraints_string
constraints_string = ",".join(constraints)

# construct llvm inline asm op
register_width_int = builtin.IntegerType(self.xlen)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not index type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm dialect doesn't like index type

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!!

ops_to_insert.append(
new_op := InlineAsmOp(
asm_string,
constraints_string,
inputs,
[register_width_int] * num_results,
)
)
op_results = new_op.results

# cast output back to original type if necessary
if num_results:
ops_to_insert.append(
output_op := UnrealizedConversionCastOp.get(
new_op.results, [r.type for r in op.results]
superlopuh marked this conversation as resolved.
Show resolved Hide resolved
)
)
op_results = output_op.results

# map results back using result_map
rewriter.replace_matched_op(
ops_to_insert,
[op_results[i] if isinstance(i, int) else i for i in result_map],
)


class ConvertRiscvToLLVMPass(ModulePass):
"""
Convert RISC-V instructions to LLVM inline assembly. This allows for the use
of an LLVM backend instead of direct RISC-V assembly generation. Additionally,
custom ops are implemented using .insn directives, to avoid the need for a
custom LLVM backend.

Only integer register types are supported. Specify register width through the
xlen pass argument.

Due to the nature of inline assembly operations, this behaviour is very flaky
for code that has been register allocated, and will most likely break for all
non-trivial register allocated code.

This pass handles register allocated operations by discarding allocated registers.
This breaks as soon as the riscv dialect code has non-SSA def-use chains (e.g.
through get_register ops).
"""

name = "convert-riscv-to-llvm"

xlen: int = 32

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(RiscvToLLVMPattern(self.xlen)).rewrite_module(op)
Loading