From 9aa070c83fe0bef9154b7a319f4f764b75248f8c Mon Sep 17 00:00:00 2001 From: Laurence Tratt Date: Thu, 3 Oct 2024 13:20:53 +0100 Subject: [PATCH] Rewrite the optimiser to be a one pass optimiser. [Well, OK the first lie is in the first line. Because we have to do backwards code generation, we actually have to do a second pass. But that's a boring detail, and something that predates this commit.] This commit moves the optimiser into a model that's closer to RPython: rather than multiple forward stages each doing one optimisation, we have a single forward stage that does all optimisations in one go. This is quicker (fairly obviously!) but also simpler, because it means that we can make strong assumptions at the point of the wavefront: no-one can change our knowledge of earlier parts of the trace. The crucial difference between this commit and what came before is that as well as mutating the IR as we go along (e.g. constant folding) we maintain an analysis of what values an instruction can produce, which we gradually refine. Crucially, mutation can only happen at the current instruction, but analysis refinement can happen on previous instructions. Consider a trace along these lines: ``` %0: i8 = load_ti ... ; at this point we know nothing about %0 %1: i8 = add %0, 1i8 ; we still know nothing about %0 %2: i1 = eq %0, 0i8 ; we know that henceforth %0 must be 0i8 guard true %2 %4: I8 = mul %0, %0 ; we know this must equal the constant 0i8 ``` Notice that it's only at the point of the guard that our analysis allows us to start using the knowledge "%0 = 0i8" for optimisations. The code is currently woefully incomplete (but if I fill in anything else, I trip up on the "removing guards tends to kill things" bug that I believe is fixed but not yet merged), I haven't yet bothered porting across one old optimisation (`mul_chain`), and I haven't really got a good API for sharing analysis and mutations (which is currently done inconsistently). All that said, this nudges the optimiser forward just about enough that I think it's worth considering for merging now, and then it can be improved in-tree. --- ykrt/src/compile/jitc_yk/aot_ir.rs | 24 +- ykrt/src/compile/jitc_yk/codegen/x64/mod.rs | 12 +- ykrt/src/compile/jitc_yk/jit_ir/mod.rs | 27 +- ykrt/src/compile/jitc_yk/opt/analyse.rs | 82 ++++ ykrt/src/compile/jitc_yk/opt/mod.rs | 494 +++++++++++++++++++- ykrt/src/compile/jitc_yk/opt/simple.rs | 429 ----------------- 6 files changed, 611 insertions(+), 457 deletions(-) create mode 100644 ykrt/src/compile/jitc_yk/opt/analyse.rs delete mode 100644 ykrt/src/compile/jitc_yk/opt/simple.rs diff --git a/ykrt/src/compile/jitc_yk/aot_ir.rs b/ykrt/src/compile/jitc_yk/aot_ir.rs index 00a08ec2a..3c6cb8d55 100644 --- a/ykrt/src/compile/jitc_yk/aot_ir.rs +++ b/ykrt/src/compile/jitc_yk/aot_ir.rs @@ -508,20 +508,42 @@ impl BBlockId { } } -/// Predicates for use in numeric comparisons. +/// Predicates for use in numeric comparisons. These are directly based on [LLVM's `icmp` +/// semantics](https://llvm.org/docs/LangRef.html#icmp-instruction). All quotes below are taken +/// from there. #[deku_derive(DekuRead)] #[derive(Debug, Eq, PartialEq, Clone, Copy)] #[deku(type = "u8")] pub(crate) enum Predicate { + /// "eq: yields true if the operands are equal, false otherwise. No sign + /// interpretation is necessary or performed." Equal = 0, + /// "ne: yields true if the operands are unequal, false otherwise. No sign + /// interpretation is necessary or performed." NotEqual, + /// "ugt: interprets the operands as unsigned values and yields true if op1 is + /// greater than op2." UnsignedGreater, + /// "uge: interprets the operands as unsigned values and yields true if op1 is + /// greater than or equal to op2." UnsignedGreaterEqual, + /// "ule: interprets the operands as unsigned values and yields true if op1 is + /// less than or equal to op2." UnsignedLess, + /// "ule: interprets the operands as unsigned values and yields true if op1 is + /// less than or equal to op2." UnsignedLessEqual, + /// "sgt: interprets the operands as signed values and yields true if op1 is greater + /// than op2." SignedGreater, + /// "sge: interprets the operands as signed values and yields true if op1 is + /// greater than or equal to op2." SignedGreaterEqual, + /// "slt: interprets the operands as signed values and yields true if op1 is less + /// than op2." SignedLess, + /// "sle: interprets the operands as signed values and yields true if op1 is less + /// than or equal to op2." SignedLessEqual, } diff --git a/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs b/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs index dbead3e7d..24b4d0b66 100644 --- a/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs +++ b/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs @@ -404,15 +404,11 @@ impl<'a> Assemble<'a> { } } - fn cg_binop( - &mut self, - iidx: jit_ir::InstIdx, - jit_ir::BinOpInst { lhs, binop, rhs }: &jit_ir::BinOpInst, - ) { - let lhs = lhs.unpack(self.m); - let rhs = rhs.unpack(self.m); + fn cg_binop(&mut self, iidx: jit_ir::InstIdx, inst: &jit_ir::BinOpInst) { + let lhs = inst.lhs(self.m); + let rhs = inst.rhs(self.m); - match binop { + match inst.binop() { BinOp::Add => { let size = lhs.byte_size(self.m); let [lhs_reg, rhs_reg] = self.ra.assign_gp_regs( diff --git a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs index 56c2cb3dc..677c73108 100644 --- a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs +++ b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs @@ -309,7 +309,7 @@ impl Module { /// This function has very few uses and unless you explicitly know why you're using it, you /// should instead use [Self::inst_no_copies] because not handling `Copy` instructions /// correctly leads to undefined behaviour. - fn inst_raw(&self, iidx: InstIdx) -> Inst { + pub(crate) fn inst_raw(&self, iidx: InstIdx) -> Inst { self.insts[usize::from(iidx)] } @@ -1166,7 +1166,7 @@ impl fmt::Display for DisplayableOperand<'_> { /// Note that this struct deliberately does not implement `PartialEq` (or `Eq`): two instances of /// `Const` may represent the same underlying constant, but (because of floats), you as the user /// need to determine what notion of equality you wish to use on a given const. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub(crate) enum Const { Float(TyIdx, f64), /// A constant integer at most 64 bits wide. This can be treated a signed or unsigned integer @@ -1185,15 +1185,6 @@ impl Const { } } - /// If this constant is an integer that can be represented in 64 bits, return it as an `i64`. - pub(crate) fn int_to_u64(&self) -> Option { - match self { - Const::Float(_, _) => None, - Const::Int(_, x) => Some(*x), - Const::Ptr(_) => None, - } - } - /// Create an integer of the same underlying type and with the value `x`. /// /// # Panics @@ -1909,7 +1900,7 @@ pub(crate) struct BinOpInst { /// The left-hand side of the operation. pub(crate) lhs: PackedOperand, /// The operation to perform. - pub(crate) binop: BinOp, + binop: BinOp, /// The right-hand side of the operation. pub(crate) rhs: PackedOperand, } @@ -1923,6 +1914,18 @@ impl BinOpInst { } } + pub(crate) fn lhs(&self, m: &Module) -> Operand { + self.lhs.unpack(m) + } + + pub(crate) fn binop(&self) -> BinOp { + self.binop + } + + pub(crate) fn rhs(&self, m: &Module) -> Operand { + self.rhs.unpack(m) + } + /// Returns the type index of the operands being added. pub(crate) fn tyidx(&self, m: &Module) -> TyIdx { self.lhs.unpack(m).tyidx(m) diff --git a/ykrt/src/compile/jitc_yk/opt/analyse.rs b/ykrt/src/compile/jitc_yk/opt/analyse.rs new file mode 100644 index 000000000..6061adbec --- /dev/null +++ b/ykrt/src/compile/jitc_yk/opt/analyse.rs @@ -0,0 +1,82 @@ +//! Analyse a trace and gradually refine what values we know a previous instruction can produce. + +use super::{ + super::jit_ir::{GuardInst, Inst, InstIdx, Module, Operand, Predicate}, + Value, +}; + +/// Ongoing analysis of a trace: what value can a given instruction in the past produce? +/// +/// Note that the analysis is forward-looking: just because an instruction's `Value` is (say) a +/// `Const` now does not mean it would be valid to assume that at earlier points it is safe to +/// assume it was also a `Const`. +pub(super) struct Analyse { + /// For each instruction, what have we learnt about its [Value] so far? + values: Vec, +} + +impl Analyse { + pub(super) fn new(m: &Module) -> Analyse { + Analyse { + values: vec![Value::Unknown; m.insts_len()], + } + } + + /// Map `op` based on our analysis so far. In some cases this will return `op` unchanged, but + /// in others it may be able to turn what looks like a variable reference into a constant. + pub(super) fn op_map(&mut self, m: &Module, op: Operand) -> Operand { + match op { + Operand::Var(iidx) => match self.values[usize::from(iidx)] { + Value::Unknown => { + // Since we last saw an `ICmp` instruction, we may have gathered new knowledge + // that allows us to turn it into a constant. + if let (iidx, Inst::ICmp(inst)) = m.inst_decopy(iidx) { + let lhs = self.op_map(m, inst.lhs(m)); + let pred = inst.predicate(); + let rhs = self.op_map(m, inst.rhs(m)); + if let (&Operand::Const(lhs_cidx), &Operand::Const(rhs_cidx)) = (&lhs, &rhs) + { + if pred == Predicate::Equal && m.const_(lhs_cidx) == m.const_(rhs_cidx) + { + self.values[usize::from(iidx)] = Value::Const(m.true_constidx()); + return Operand::Const(m.true_constidx()); + } + } + } + op + } + Value::Const(cidx) => Operand::Const(cidx), + }, + Operand::Const(_) => op, + } + } + + /// Update our idea of what value the instruction at `iidx` can produce. + pub(super) fn set_value(&mut self, iidx: InstIdx, v: Value) { + self.values[usize::from(iidx)] = v; + } + + /// Use the guard `inst` to update our knowledge about the variable used as its condition. + pub(super) fn guard(&mut self, m: &Module, inst: GuardInst) { + if let Operand::Var(iidx) = inst.cond(m) { + if let (_, Inst::ICmp(inst)) = m.inst_decopy(iidx) { + let lhs = self.op_map(m, inst.lhs(m)); + let pred = inst.predicate(); + let rhs = self.op_map(m, inst.rhs(m)); + match (&lhs, &rhs) { + (&Operand::Const(_), &Operand::Const(_)) => { + // This will have been handled by icmp/guard optimisations. + unreachable!(); + } + (&Operand::Var(iidx), &Operand::Const(cidx)) + | (&Operand::Const(cidx), &Operand::Var(iidx)) => { + if pred == Predicate::Equal { + self.set_value(iidx, Value::Const(cidx)); + } + } + (&Operand::Var(_), &Operand::Var(_)) => (), + } + } + } + } +} diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index bb350790f..91d1bbff9 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -1,13 +1,493 @@ -use super::jit_ir::Module; +// A trace IR optimiser. +// +// The optimiser works in a single forward pass (well, it also does a single backwards pass at the +// end too, but that's only because we can't yet do backwards code generation). As it progresses +// through a trace, it both mutates the trace IR directly and also refines its idea about what +// value an instruction might produce. These two actions are subtly different: mutation is done +// in this module; the refinement of values in the [Analyse] module. + +use super::jit_ir::{ + BinOp, BinOpInst, Const, ConstIdx, ICmpInst, Inst, InstIdx, Module, Operand, Predicate, Ty, +}; use crate::compile::CompilationError; -mod simple; +mod analyse; + +use analyse::Analyse; + +struct Opt { + m: Module, + an: Analyse, +} + +impl Opt { + fn new(m: Module) -> Self { + let an = Analyse::new(&m); + Self { m, an } + } + + fn opt(mut self) -> Result { + for iidx in self.m.iter_all_inst_idxs() { + match self.m.inst_raw(iidx) { + #[cfg(test)] + Inst::BlackBox(_) => (), + Inst::Const(cidx) => self.an.set_value(iidx, Value::Const(cidx)), + Inst::BinOp(x) => match x.binop() { + BinOp::Add => (), + BinOp::Mul => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Const(cidx), Operand::Var(copy_iidx)) + | (Operand::Var(copy_iidx), Operand::Const(cidx)) => { + match self.m.const_(cidx) { + Const::Int(_, 0) => { + // Replace `x * 0` with `0`. + self.m.replace(iidx, Inst::Const(cidx)); + } + Const::Int(_, 1) => { + // Replace `x * 1` with `x`. + self.m.replace(iidx, Inst::Copy(copy_iidx)); + } + Const::Int(ty_idx, x) if x.is_power_of_two() => { + // Replace `x * y` with `x << ...`. + let shl = u64::from(x.ilog2()); + let shl_op = Operand::Const( + self.m.insert_const(Const::Int(*ty_idx, shl))?, + ); + let new_inst = + BinOpInst::new(Operand::Var(copy_iidx), BinOp::Shl, shl_op) + .into(); + self.m.replace(iidx, new_inst); + } + _ => (), + } + } + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + let lhs_c = self.m.const_(lhs_cidx); + let rhs_c = self.m.const_(rhs_cidx); + match (lhs_c, rhs_c) { + (Const::Int(lhs_ty, lhs_v), Const::Int(rhs_ty, rhs_v)) => { + debug_assert_eq!(lhs_ty, rhs_ty); + let Ty::Integer(bits) = self.m.type_(*lhs_ty) else { + panic!() + }; + let mul = lhs_v.sign_extend(*bits) * rhs_v.sign_extend(*bits); + let trun = mul.truncate(*bits); + let cidx = self.m.insert_const(lhs_c.u64_to_int(trun))?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => todo!(), + } + } + (Operand::Var(_), Operand::Var(_)) => (), + }, + _ => (), + }, + Inst::ICmp(x) => { + self.icmp(iidx, x); + } + Inst::Guard(x) => { + if let Operand::Const(_) = self.an.op_map(&self.m, x.cond(&self.m)) { + // A guard that references a constant is, by definition, not needed and + // doesn't affect future analyses. + self.m.replace(iidx, Inst::Tombstone); + } else { + self.an.guard(&self.m, x); + } + } + _ => (), + } + } + // FIXME: When code generation supports backwards register allocation, we won't need to + // explicitly perform dead code elimination and this function can be made `#[cfg(test)]` only. + self.m.dead_code_elimination(); + Ok(self.m) + } + + /// Optimise an [ICmpInst]. + fn icmp(&mut self, iidx: InstIdx, inst: ICmpInst) { + let lhs = self.an.op_map(&self.m, inst.lhs(&self.m)); + let pred = inst.predicate(); + let rhs = self.an.op_map(&self.m, inst.rhs(&self.m)); + match (&lhs, &rhs) { + (&Operand::Const(lhs_cidx), &Operand::Const(rhs_cidx)) => { + self.icmp_both_const(iidx, lhs_cidx, pred, rhs_cidx) + } + (&Operand::Var(_), &Operand::Const(_)) | (&Operand::Const(_), &Operand::Var(_)) => (), + (&Operand::Var(_), &Operand::Var(_)) => (), + } + } + + /// Optimise an `ICmp` if both sides are constants. It is required that [Opt::op_map] has been + /// called on both `lhs` and `rhs` to obtain the `ConstIdx`s. + fn icmp_both_const(&mut self, iidx: InstIdx, lhs: ConstIdx, pred: Predicate, rhs: ConstIdx) { + let lhs_c = self.m.const_(lhs); + let rhs_c = self.m.const_(rhs); + match (lhs_c, rhs_c) { + (Const::Float(..), Const::Float(..)) => (), + (Const::Int(lhs_tyidx, x), Const::Int(rhs_tyidx, y)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + // Constant fold comparisons of simple integers. + let x = *x; + let y = *y; + let r = match pred { + Predicate::Equal => x == y, + Predicate::NotEqual => x != y, + Predicate::UnsignedGreater => x > y, + Predicate::UnsignedGreaterEqual => x >= y, + Predicate::UnsignedLess => x < y, + Predicate::UnsignedLessEqual => x <= y, + Predicate::SignedGreater => (x as i64) > (y as i64), + Predicate::SignedGreaterEqual => (x as i64) >= (y as i64), + Predicate::SignedLess => (x as i64) < (y as i64), + Predicate::SignedLessEqual => (x as i64) <= (y as i64), + }; + + self.m.replace( + iidx, + Inst::Const(if r { + self.m.true_constidx() + } else { + self.m.false_constidx() + }), + ); + } + (Const::Ptr(x), Const::Ptr(y)) => { + // Constant fold comparisons of pointers. + let x = *x; + let y = *y; + let r = match pred { + Predicate::Equal => x == y, + Predicate::NotEqual => x != y, + Predicate::UnsignedGreater => x > y, + Predicate::UnsignedGreaterEqual => x >= y, + Predicate::UnsignedLess => x < y, + Predicate::UnsignedLessEqual => x <= y, + Predicate::SignedGreater => (x as i64) > (y as i64), + Predicate::SignedGreaterEqual => (x as i64) >= (y as i64), + Predicate::SignedLess => (x as i64) < (y as i64), + Predicate::SignedLessEqual => (x as i64) <= (y as i64), + }; + + self.m.replace( + iidx, + Inst::Const(if r { + self.m.true_constidx() + } else { + self.m.false_constidx() + }), + ); + } + _ => unreachable!(), + } + } +} + +trait SignExtend { + fn sign_extend(&self, bits: u32) -> Self; +} + +impl SignExtend for u64 { + fn sign_extend(&self, bits: u32) -> Self { + debug_assert!( + bits > 0 && bits <= Self::BITS, + "{bits} outside range 1..={}", + Self::BITS + ); + let shift = Self::BITS - bits; + (*self << shift) >> shift + } +} + +trait Truncate { + fn truncate(&self, bits: u32) -> Self; +} + +impl Truncate for u64 { + fn truncate(&self, bits: u32) -> Self { + debug_assert!( + bits > 0 && bits <= Self::BITS, + "{bits} outside range 1..={}", + Self::BITS + ); + *self & ((1 as Self).wrapping_shl(bits) - 1) + } +} + +#[derive(Clone, Debug)] +enum Value { + Unknown, + Const(ConstIdx), +} /// Create JIT IR from the (`aot_mod`, `ta_iter`) tuple. pub(super) fn opt(m: Module) -> Result { - let mut m = simple::simple(m)?; - // FIXME: When code generation supports backwards register allocation, we won't need to - // explicitly perform dead code elimination and this function can be made `#[cfg(test)]` only. - m.dead_code_elimination(); - Ok(m) + Opt::new(m).opt() +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn opt_const_guard() { + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = load_ti 0 + guard false, 0i1, [%0] + ", + |m| opt(m).unwrap(), + " + ... + entry: + ", + ); + } + + #[test] + fn opt_const_guard_indirect() { + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = eq 0i8, 0i8 + guard true, %0, [] + %1: i1 = eq 0i8, 1i8 + guard false, %1, [%0] + ", + |m| opt(m).unwrap(), + " + ... + entry: + ", + ); + } + + #[test] + fn opt_const_guard_chain() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = mul %0, 0i8 + %2: i1 = eq %1, 0i8 + guard true, %2, [%0, %1] + black_box %0 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = load_ti ... + black_box %0 + ", + ); + } + + #[test] + fn opt_mul_zero() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = load_ti 1 + %2: i8 = mul %0, 0i8 + %3: i8 = add %1, %2 + %4: i8 = mul 0i8, %0 + %5: i8 = add %1, %2 + black_box %3 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %1: i8 = load_ti ... + %3: i8 = add %1, 0i8 + %5: i8 = add %1, 0i8 + black_box %3 + black_box %5 + ", + ); + } + + #[test] + fn opt_mul_one() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = load_ti 1 + %2: i8 = mul %0, 1i8 + %3: i8 = add %1, %2 + %4: i8 = mul 1i8, %0 + %5: i8 = add %1, %2 + black_box %3 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = load_ti ... + %1: i8 = load_ti ... + %3: i8 = add %1, %0 + %5: i8 = add %1, %0 + black_box %3 + black_box %5 + ", + ); + } + + #[test] + fn opt_mul_const() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = mul %0, 0i8 + %2: i8 = mul %0, 0i8 + %3: i8 = mul %1, %2 + black_box %3 + ", + |m| opt(m).unwrap(), + " + ... + entry: + black_box 0i8 + ", + ); + } + + #[test] + fn opt_mul_shl() { + Module::assert_ir_transform_eq( + " + entry: + %0: i64 = load_ti 0 + %1: i64 = mul %0, 2i64 + %2: i64 = mul %0, 4i64 + %3: i64 = mul %0, 4611686018427387904i64 + %4: i64 = mul %0, 9223372036854775807i64 + %5: i64 = mul %0, 12i64 + black_box %1 + black_box %2 + black_box %3 + black_box %4 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i64 = load_ti ... + %1: i64 = shl %0, 1i64 + %2: i64 = shl %0, 2i64 + %3: i64 = shl %0, 62i64 + %4: i64 = mul %0, 9223372036854775807i64 + %5: i64 = mul %0, 12i64 + black_box ... + ... + ", + ); + } + + #[test] + fn opt_icmp_const() { + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = eq 0i8, 0i8 + %1: i1 = eq 0i8, 1i8 + %2: i1 = ne 0i8, 0i8 + %3: i1 = ne 0i8, 1i8 + %4: i1 = ugt 0i8, 0i8 + %5: i1 = ugt 0i8, 1i8 + %6: i1 = ugt 1i8, 0i8 + %7: i1 = uge 0i8, 0i8 + %8: i1 = uge 0i8, 1i8 + %9: i1 = uge 1i8, 0i8 + %10: i1 = ult 0i8, 0i8 + %11: i1 = ult 0i8, 1i8 + %12: i1 = ult 1i8, 0i8 + %13: i1 = ule 0i8, 0i8 + %14: i1 = ule 0i8, 1i8 + %15: i1 = ule 1i8, 0i8 + %16: i1 = sgt 0i8, 0i8 + %17: i1 = sgt 0i8, -1i8 + %18: i1 = sgt -1i8, 0i8 + %19: i1 = sge 0i8, 0i8 + %20: i1 = sge 0i8, -1i8 + %21: i1 = sge -1i8, 0i8 + %22: i1 = slt 0i8, 0i8 + %23: i1 = slt 0i8, -1i8 + %24: i1 = slt -1i8, 0i8 + %25: i1 = sle 0i8, 0i8 + %26: i1 = sle 0i8, -1i8 + %27: i1 = sle -1i8, 0i8 + black_box %0 + black_box %1 + black_box %2 + black_box %3 + black_box %4 + black_box %5 + black_box %6 + black_box %7 + black_box %8 + black_box %9 + black_box %10 + black_box %11 + black_box %12 + black_box %13 + black_box %14 + black_box %15 + black_box %16 + black_box %17 + black_box %18 + black_box %19 + black_box %20 + black_box %21 + black_box %22 + black_box %23 + black_box %24 + black_box %25 + black_box %26 + black_box %27 + ", + |m| opt(m).unwrap(), + " + ... + entry: + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + ", + ); + } } diff --git a/ykrt/src/compile/jitc_yk/opt/simple.rs b/ykrt/src/compile/jitc_yk/opt/simple.rs deleted file mode 100644 index b50022d6d..000000000 --- a/ykrt/src/compile/jitc_yk/opt/simple.rs +++ /dev/null @@ -1,429 +0,0 @@ -//! Simple, local optimisations. -//! -//! These include strength reductions and other optimisations that can be performed with little -//! analysis. - -use crate::compile::{ - jitc_yk::jit_ir::{ - BinOp, BinOpInst, GuardInst, ICmpInst, Inst, InstIdx, Module, Operand, PackedOperand, - Predicate, - }, - CompilationError, -}; - -pub(super) fn simple(mut m: Module) -> Result { - for iidx in m.iter_all_inst_idxs() { - let (iidx, inst) = m.inst_decopy(iidx); - match inst { - Inst::BinOp(BinOpInst { - lhs, - binop: BinOp::Mul, - rhs, - }) => opt_mul(&mut m, iidx, lhs, rhs)?, - Inst::Guard(x) => opt_guard(&mut m, iidx, x)?, - Inst::ICmp(x) => opt_icmp(&mut m, iidx, x)?, - _ => (), - } - } - Ok(m) -} - -fn opt_mul( - m: &mut Module, - iidx: InstIdx, - lhs: PackedOperand, - rhs: PackedOperand, -) -> Result<(), CompilationError> { - match (lhs.unpack(m), rhs.unpack(m)) { - (Operand::Var(mut mul_inst), Operand::Const(mul_const)) - | (Operand::Const(mul_const), Operand::Var(mut mul_inst)) => { - let old_const = m.const_(mul_const); - if let Some(old_val) = old_const.int_to_u64() { - let mut new_val = old_val; - // If we've got `%2: mul %1, xi8` then see if `%1` is of the form `mul %0, yi8`: if so - // we've got a chain that's `%2: %0*x*y`. We can thus "skip" the intermediate `mul` - // when calculating the constant we're going to optimise. - if let Inst::BinOp(BinOpInst { - lhs: chain_lhs, - binop: BinOp::Mul, - rhs: chain_rhs, - }) = m.inst_no_copies(mul_inst) - { - if let (Operand::Var(chain_mul_inst), Operand::Const(chain_mul_const)) - | (Operand::Const(chain_mul_const), Operand::Var(chain_mul_inst)) = - (chain_lhs.unpack(m), chain_rhs.unpack(m)) - { - if let Some(y) = m.const_(chain_mul_const).int_to_u64() { - mul_inst = chain_mul_inst; - new_val = old_val * y; - } - } - } - - if new_val == 0 { - // Replace `x * 0` with `0`. - let cidx = m.insert_const(old_const.u64_to_int(0))?; - m.replace(iidx, Inst::Const(cidx)); - } else if new_val == 1 { - // Replace `x * 1` with `x`. - m.replace(iidx, Inst::Copy(mul_inst)); - } else if new_val.is_power_of_two() { - // Replace `x * y` with `x << ...`. - let shl = u64::from(new_val.ilog2()); - let new_const = Operand::Const(m.insert_const(old_const.u64_to_int(shl))?); - let new_inst = - BinOpInst::new(Operand::Var(mul_inst), BinOp::Shl, new_const).into(); - m.replace(iidx, new_inst); - } else if new_val != old_val { - let new_const = Operand::Const(m.insert_const(old_const.u64_to_int(new_val))?); - let new_inst = - BinOpInst::new(Operand::Var(mul_inst), BinOp::Mul, new_const).into(); - m.replace(iidx, new_inst); - } - } - } - (Operand::Const(x), Operand::Const(y)) => { - // Constant fold the unsigned multiplication of two constants. - let x = m.const_(x); - let y = m.const_(y); - // If `x_val * y_val` overflows, we're fine with the UB, as the interpreter - // author is at fault. - let new_val = x.int_to_u64().unwrap() * y.int_to_u64().unwrap(); - let new_const = m.insert_const(x.u64_to_int(new_val))?; - m.replace(iidx, Inst::Const(new_const)); - } - (Operand::Var(_), Operand::Var(_)) => (), - } - Ok(()) -} - -fn opt_icmp( - m: &mut Module, - iidx: InstIdx, - ICmpInst { lhs, pred, rhs }: ICmpInst, -) -> Result<(), CompilationError> { - if let (Operand::Const(x), Operand::Const(y)) = (lhs.unpack(m), rhs.unpack(m)) { - if let (Some(x), Some(y)) = (m.const_(x).int_to_u64(), m.const_(y).int_to_u64()) { - // Constant fold comparisons of simple integers. Note that we have to follow the - // LLVM semantics carefully. The quotes in the `match` below are from - // https://llvm.org/docs/LangRef.html#icmp-instruction. - let r = match pred { - // "eq: yields true if the operands are equal, false otherwise. No sign - // interpretation is necessary or performed." - Predicate::Equal => x == y, - // "ne: yields true if the operands are unequal, false otherwise. No sign - // interpretation is necessary or performed." - Predicate::NotEqual => x != y, - // "ugt: interprets the operands as unsigned values and yields true if op1 is - // greater than op2." - Predicate::UnsignedGreater => x > y, - // "uge: interprets the operands as unsigned values and yields true if op1 is - // greater than or equal to op2." - Predicate::UnsignedGreaterEqual => x >= y, - // "ult: interprets the operands as unsigned values and yields true if op1 is - // less than op2." - Predicate::UnsignedLess => x < y, - // "ule: interprets the operands as unsigned values and yields true if op1 is - // less than or equal to op2." - Predicate::UnsignedLessEqual => x <= y, - // "interprets the operands as signed values and yields true if op1 is greater - // than op2." - Predicate::SignedGreater => (x as i64) > (y as i64), - // "sge: interprets the operands as signed values and yields true if op1 is - // greater than or equal to op2." - Predicate::SignedGreaterEqual => (x as i64) >= (y as i64), - // "slt: interprets the operands as signed values and yields true if op1 is less - // than op2." - Predicate::SignedLess => (x as i64) < (y as i64), - // "sle: interprets the operands as signed values and yields true if op1 is less - // than or equal to op2." - Predicate::SignedLessEqual => (x as i64) <= (y as i64), - }; - - if r { - m.replace(iidx, Inst::Const(m.true_constidx())); - } else { - m.replace(iidx, Inst::Const(m.false_constidx())); - } - } - } - - Ok(()) -} - -fn opt_guard( - m: &mut Module, - iidx: InstIdx, - GuardInst { - cond, - expect: _, - gidx: _, - }: GuardInst, -) -> Result<(), CompilationError> { - if let Operand::Const(_) = cond.unpack(m) { - // A guard that references a constant is, by definition, not useful. - m.replace(iidx, Inst::Tombstone); - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn opt_mul_zero() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = load_ti 1 - %2: i8 = mul %0, 0i8 - %3: i8 = add %1, %2 - %4: i8 = mul 0i8, %0 - %5: i8 = add %1, %2 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - %1: i8 = load_ti ... - %3: i8 = add %1, 0i8 - %5: i8 = add %1, 0i8 - ", - ); - } - - #[test] - fn opt_mul_one() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = load_ti 1 - %2: i8 = mul %0, 1i8 - %3: i8 = add %1, %2 - %4: i8 = mul 1i8, %0 - %5: i8 = add %1, %2 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - %1: i8 = load_ti ... - %3: i8 = add %1, %0 - %5: i8 = add %1, %0 - ", - ); - } - - #[test] - fn opt_mul_const() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = mul %0, 0i8 - %2: i8 = mul %0, 0i8 - %3: i8 = mul %1, %2 - black_box %3 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - black_box 0i8 - ", - ); - } - - #[test] - fn opt_mul_chain() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = mul %0, 3i8 - %2: i8 = mul %1, 4i8 - %3: i8 = mul %2, 5i8 - black_box %3 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - %1: i8 = mul %0, 3i8 - %2: i8 = mul %0, 12i8 - %3: i8 = mul %0, 60i8 - black_box %3 - ", - ); - } - - #[test] - fn opt_mul_shl() { - Module::assert_ir_transform_eq( - " - entry: - %0: i64 = load_ti 0 - %1: i64 = mul %0, 2i64 - %2: i64 = mul %0, 4i64 - %3: i64 = mul %0, 4611686018427387904i64 - %4: i64 = mul %0, 9223372036854775807i64 - %5: i64 = mul %0, 12i64 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i64 = load_ti ... - %1: i64 = shl %0, 1i64 - %2: i64 = shl %0, 2i64 - %3: i64 = shl %0, 62i64 - %4: i64 = mul %0, 9223372036854775807i64 - %5: i64 = mul %0, 12i64 - ", - ); - } - - #[test] - fn opt_icmp_const() { - Module::assert_ir_transform_eq( - " - entry: - %0: i1 = eq 0i8, 0i8 - %1: i1 = eq 0i8, 1i8 - %2: i1 = ne 0i8, 0i8 - %3: i1 = ne 0i8, 1i8 - %4: i1 = ugt 0i8, 0i8 - %5: i1 = ugt 0i8, 1i8 - %6: i1 = ugt 1i8, 0i8 - %7: i1 = uge 0i8, 0i8 - %8: i1 = uge 0i8, 1i8 - %9: i1 = uge 1i8, 0i8 - %10: i1 = ult 0i8, 0i8 - %11: i1 = ult 0i8, 1i8 - %12: i1 = ult 1i8, 0i8 - %13: i1 = ule 0i8, 0i8 - %14: i1 = ule 0i8, 1i8 - %15: i1 = ule 1i8, 0i8 - %16: i1 = sgt 0i8, 0i8 - %17: i1 = sgt 0i8, -1i8 - %18: i1 = sgt -1i8, 0i8 - %19: i1 = sge 0i8, 0i8 - %20: i1 = sge 0i8, -1i8 - %21: i1 = sge -1i8, 0i8 - %22: i1 = slt 0i8, 0i8 - %23: i1 = slt 0i8, -1i8 - %24: i1 = slt -1i8, 0i8 - %25: i1 = sle 0i8, 0i8 - %26: i1 = sle 0i8, -1i8 - %27: i1 = sle -1i8, 0i8 - black_box %0 - black_box %1 - black_box %2 - black_box %3 - black_box %4 - black_box %5 - black_box %6 - black_box %7 - black_box %8 - black_box %9 - black_box %10 - black_box %11 - black_box %12 - black_box %13 - black_box %14 - black_box %15 - black_box %16 - black_box %17 - black_box %18 - black_box %19 - black_box %20 - black_box %21 - black_box %22 - black_box %23 - black_box %24 - black_box %25 - black_box %26 - black_box %27 - ", - |m| simple(m).unwrap(), - " - ... - entry: - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - ", - ); - } - - #[test] - fn opt_const_guard() { - Module::assert_ir_transform_eq( - " - entry: - %0: i1 = eq 0i8, 0i8 - guard true, %0, [] - %1: i1 = eq 0i8, 1i8 - guard false, %1, [%0] - ", - |m| simple(m).unwrap(), - " - ... - entry: - ", - ); - } - - #[test] - fn opt_const_guard_chain() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = mul %0, 0i8 - %2: i1 = eq %1, 0i8 - guard true, %2, [%0, %1] - black_box %0 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - black_box %0 - ", - ); - } -}