diff --git a/ykrt/src/compile/jitc_yk/aot_ir.rs b/ykrt/src/compile/jitc_yk/aot_ir.rs index 00a08ec2a..3c6cb8d55 100644 --- a/ykrt/src/compile/jitc_yk/aot_ir.rs +++ b/ykrt/src/compile/jitc_yk/aot_ir.rs @@ -508,20 +508,42 @@ impl BBlockId { } } -/// Predicates for use in numeric comparisons. +/// Predicates for use in numeric comparisons. These are directly based on [LLVM's `icmp` +/// semantics](https://llvm.org/docs/LangRef.html#icmp-instruction). All quotes below are taken +/// from there. #[deku_derive(DekuRead)] #[derive(Debug, Eq, PartialEq, Clone, Copy)] #[deku(type = "u8")] pub(crate) enum Predicate { + /// "eq: yields true if the operands are equal, false otherwise. No sign + /// interpretation is necessary or performed." Equal = 0, + /// "ne: yields true if the operands are unequal, false otherwise. No sign + /// interpretation is necessary or performed." NotEqual, + /// "ugt: interprets the operands as unsigned values and yields true if op1 is + /// greater than op2." UnsignedGreater, + /// "uge: interprets the operands as unsigned values and yields true if op1 is + /// greater than or equal to op2." UnsignedGreaterEqual, + /// "ule: interprets the operands as unsigned values and yields true if op1 is + /// less than or equal to op2." UnsignedLess, + /// "ule: interprets the operands as unsigned values and yields true if op1 is + /// less than or equal to op2." UnsignedLessEqual, + /// "sgt: interprets the operands as signed values and yields true if op1 is greater + /// than op2." SignedGreater, + /// "sge: interprets the operands as signed values and yields true if op1 is + /// greater than or equal to op2." SignedGreaterEqual, + /// "slt: interprets the operands as signed values and yields true if op1 is less + /// than op2." SignedLess, + /// "sle: interprets the operands as signed values and yields true if op1 is less + /// than or equal to op2." SignedLessEqual, } diff --git a/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs b/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs index dbead3e7d..24b4d0b66 100644 --- a/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs +++ b/ykrt/src/compile/jitc_yk/codegen/x64/mod.rs @@ -404,15 +404,11 @@ impl<'a> Assemble<'a> { } } - fn cg_binop( - &mut self, - iidx: jit_ir::InstIdx, - jit_ir::BinOpInst { lhs, binop, rhs }: &jit_ir::BinOpInst, - ) { - let lhs = lhs.unpack(self.m); - let rhs = rhs.unpack(self.m); + fn cg_binop(&mut self, iidx: jit_ir::InstIdx, inst: &jit_ir::BinOpInst) { + let lhs = inst.lhs(self.m); + let rhs = inst.rhs(self.m); - match binop { + match inst.binop() { BinOp::Add => { let size = lhs.byte_size(self.m); let [lhs_reg, rhs_reg] = self.ra.assign_gp_regs( diff --git a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs index 56c2cb3dc..677c73108 100644 --- a/ykrt/src/compile/jitc_yk/jit_ir/mod.rs +++ b/ykrt/src/compile/jitc_yk/jit_ir/mod.rs @@ -309,7 +309,7 @@ impl Module { /// This function has very few uses and unless you explicitly know why you're using it, you /// should instead use [Self::inst_no_copies] because not handling `Copy` instructions /// correctly leads to undefined behaviour. - fn inst_raw(&self, iidx: InstIdx) -> Inst { + pub(crate) fn inst_raw(&self, iidx: InstIdx) -> Inst { self.insts[usize::from(iidx)] } @@ -1166,7 +1166,7 @@ impl fmt::Display for DisplayableOperand<'_> { /// Note that this struct deliberately does not implement `PartialEq` (or `Eq`): two instances of /// `Const` may represent the same underlying constant, but (because of floats), you as the user /// need to determine what notion of equality you wish to use on a given const. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub(crate) enum Const { Float(TyIdx, f64), /// A constant integer at most 64 bits wide. This can be treated a signed or unsigned integer @@ -1185,15 +1185,6 @@ impl Const { } } - /// If this constant is an integer that can be represented in 64 bits, return it as an `i64`. - pub(crate) fn int_to_u64(&self) -> Option { - match self { - Const::Float(_, _) => None, - Const::Int(_, x) => Some(*x), - Const::Ptr(_) => None, - } - } - /// Create an integer of the same underlying type and with the value `x`. /// /// # Panics @@ -1909,7 +1900,7 @@ pub(crate) struct BinOpInst { /// The left-hand side of the operation. pub(crate) lhs: PackedOperand, /// The operation to perform. - pub(crate) binop: BinOp, + binop: BinOp, /// The right-hand side of the operation. pub(crate) rhs: PackedOperand, } @@ -1923,6 +1914,18 @@ impl BinOpInst { } } + pub(crate) fn lhs(&self, m: &Module) -> Operand { + self.lhs.unpack(m) + } + + pub(crate) fn binop(&self) -> BinOp { + self.binop + } + + pub(crate) fn rhs(&self, m: &Module) -> Operand { + self.rhs.unpack(m) + } + /// Returns the type index of the operands being added. pub(crate) fn tyidx(&self, m: &Module) -> TyIdx { self.lhs.unpack(m).tyidx(m) diff --git a/ykrt/src/compile/jitc_yk/opt/analyse.rs b/ykrt/src/compile/jitc_yk/opt/analyse.rs new file mode 100644 index 000000000..6061adbec --- /dev/null +++ b/ykrt/src/compile/jitc_yk/opt/analyse.rs @@ -0,0 +1,82 @@ +//! Analyse a trace and gradually refine what values we know a previous instruction can produce. + +use super::{ + super::jit_ir::{GuardInst, Inst, InstIdx, Module, Operand, Predicate}, + Value, +}; + +/// Ongoing analysis of a trace: what value can a given instruction in the past produce? +/// +/// Note that the analysis is forward-looking: just because an instruction's `Value` is (say) a +/// `Const` now does not mean it would be valid to assume that at earlier points it is safe to +/// assume it was also a `Const`. +pub(super) struct Analyse { + /// For each instruction, what have we learnt about its [Value] so far? + values: Vec, +} + +impl Analyse { + pub(super) fn new(m: &Module) -> Analyse { + Analyse { + values: vec![Value::Unknown; m.insts_len()], + } + } + + /// Map `op` based on our analysis so far. In some cases this will return `op` unchanged, but + /// in others it may be able to turn what looks like a variable reference into a constant. + pub(super) fn op_map(&mut self, m: &Module, op: Operand) -> Operand { + match op { + Operand::Var(iidx) => match self.values[usize::from(iidx)] { + Value::Unknown => { + // Since we last saw an `ICmp` instruction, we may have gathered new knowledge + // that allows us to turn it into a constant. + if let (iidx, Inst::ICmp(inst)) = m.inst_decopy(iidx) { + let lhs = self.op_map(m, inst.lhs(m)); + let pred = inst.predicate(); + let rhs = self.op_map(m, inst.rhs(m)); + if let (&Operand::Const(lhs_cidx), &Operand::Const(rhs_cidx)) = (&lhs, &rhs) + { + if pred == Predicate::Equal && m.const_(lhs_cidx) == m.const_(rhs_cidx) + { + self.values[usize::from(iidx)] = Value::Const(m.true_constidx()); + return Operand::Const(m.true_constidx()); + } + } + } + op + } + Value::Const(cidx) => Operand::Const(cidx), + }, + Operand::Const(_) => op, + } + } + + /// Update our idea of what value the instruction at `iidx` can produce. + pub(super) fn set_value(&mut self, iidx: InstIdx, v: Value) { + self.values[usize::from(iidx)] = v; + } + + /// Use the guard `inst` to update our knowledge about the variable used as its condition. + pub(super) fn guard(&mut self, m: &Module, inst: GuardInst) { + if let Operand::Var(iidx) = inst.cond(m) { + if let (_, Inst::ICmp(inst)) = m.inst_decopy(iidx) { + let lhs = self.op_map(m, inst.lhs(m)); + let pred = inst.predicate(); + let rhs = self.op_map(m, inst.rhs(m)); + match (&lhs, &rhs) { + (&Operand::Const(_), &Operand::Const(_)) => { + // This will have been handled by icmp/guard optimisations. + unreachable!(); + } + (&Operand::Var(iidx), &Operand::Const(cidx)) + | (&Operand::Const(cidx), &Operand::Var(iidx)) => { + if pred == Predicate::Equal { + self.set_value(iidx, Value::Const(cidx)); + } + } + (&Operand::Var(_), &Operand::Var(_)) => (), + } + } + } + } +} diff --git a/ykrt/src/compile/jitc_yk/opt/mod.rs b/ykrt/src/compile/jitc_yk/opt/mod.rs index bb350790f..91d1bbff9 100644 --- a/ykrt/src/compile/jitc_yk/opt/mod.rs +++ b/ykrt/src/compile/jitc_yk/opt/mod.rs @@ -1,13 +1,493 @@ -use super::jit_ir::Module; +// A trace IR optimiser. +// +// The optimiser works in a single forward pass (well, it also does a single backwards pass at the +// end too, but that's only because we can't yet do backwards code generation). As it progresses +// through a trace, it both mutates the trace IR directly and also refines its idea about what +// value an instruction might produce. These two actions are subtly different: mutation is done +// in this module; the refinement of values in the [Analyse] module. + +use super::jit_ir::{ + BinOp, BinOpInst, Const, ConstIdx, ICmpInst, Inst, InstIdx, Module, Operand, Predicate, Ty, +}; use crate::compile::CompilationError; -mod simple; +mod analyse; + +use analyse::Analyse; + +struct Opt { + m: Module, + an: Analyse, +} + +impl Opt { + fn new(m: Module) -> Self { + let an = Analyse::new(&m); + Self { m, an } + } + + fn opt(mut self) -> Result { + for iidx in self.m.iter_all_inst_idxs() { + match self.m.inst_raw(iidx) { + #[cfg(test)] + Inst::BlackBox(_) => (), + Inst::Const(cidx) => self.an.set_value(iidx, Value::Const(cidx)), + Inst::BinOp(x) => match x.binop() { + BinOp::Add => (), + BinOp::Mul => match ( + self.an.op_map(&self.m, x.lhs(&self.m)), + self.an.op_map(&self.m, x.rhs(&self.m)), + ) { + (Operand::Const(cidx), Operand::Var(copy_iidx)) + | (Operand::Var(copy_iidx), Operand::Const(cidx)) => { + match self.m.const_(cidx) { + Const::Int(_, 0) => { + // Replace `x * 0` with `0`. + self.m.replace(iidx, Inst::Const(cidx)); + } + Const::Int(_, 1) => { + // Replace `x * 1` with `x`. + self.m.replace(iidx, Inst::Copy(copy_iidx)); + } + Const::Int(ty_idx, x) if x.is_power_of_two() => { + // Replace `x * y` with `x << ...`. + let shl = u64::from(x.ilog2()); + let shl_op = Operand::Const( + self.m.insert_const(Const::Int(*ty_idx, shl))?, + ); + let new_inst = + BinOpInst::new(Operand::Var(copy_iidx), BinOp::Shl, shl_op) + .into(); + self.m.replace(iidx, new_inst); + } + _ => (), + } + } + (Operand::Const(lhs_cidx), Operand::Const(rhs_cidx)) => { + let lhs_c = self.m.const_(lhs_cidx); + let rhs_c = self.m.const_(rhs_cidx); + match (lhs_c, rhs_c) { + (Const::Int(lhs_ty, lhs_v), Const::Int(rhs_ty, rhs_v)) => { + debug_assert_eq!(lhs_ty, rhs_ty); + let Ty::Integer(bits) = self.m.type_(*lhs_ty) else { + panic!() + }; + let mul = lhs_v.sign_extend(*bits) * rhs_v.sign_extend(*bits); + let trun = mul.truncate(*bits); + let cidx = self.m.insert_const(lhs_c.u64_to_int(trun))?; + self.m.replace(iidx, Inst::Const(cidx)); + } + _ => todo!(), + } + } + (Operand::Var(_), Operand::Var(_)) => (), + }, + _ => (), + }, + Inst::ICmp(x) => { + self.icmp(iidx, x); + } + Inst::Guard(x) => { + if let Operand::Const(_) = self.an.op_map(&self.m, x.cond(&self.m)) { + // A guard that references a constant is, by definition, not needed and + // doesn't affect future analyses. + self.m.replace(iidx, Inst::Tombstone); + } else { + self.an.guard(&self.m, x); + } + } + _ => (), + } + } + // FIXME: When code generation supports backwards register allocation, we won't need to + // explicitly perform dead code elimination and this function can be made `#[cfg(test)]` only. + self.m.dead_code_elimination(); + Ok(self.m) + } + + /// Optimise an [ICmpInst]. + fn icmp(&mut self, iidx: InstIdx, inst: ICmpInst) { + let lhs = self.an.op_map(&self.m, inst.lhs(&self.m)); + let pred = inst.predicate(); + let rhs = self.an.op_map(&self.m, inst.rhs(&self.m)); + match (&lhs, &rhs) { + (&Operand::Const(lhs_cidx), &Operand::Const(rhs_cidx)) => { + self.icmp_both_const(iidx, lhs_cidx, pred, rhs_cidx) + } + (&Operand::Var(_), &Operand::Const(_)) | (&Operand::Const(_), &Operand::Var(_)) => (), + (&Operand::Var(_), &Operand::Var(_)) => (), + } + } + + /// Optimise an `ICmp` if both sides are constants. It is required that [Opt::op_map] has been + /// called on both `lhs` and `rhs` to obtain the `ConstIdx`s. + fn icmp_both_const(&mut self, iidx: InstIdx, lhs: ConstIdx, pred: Predicate, rhs: ConstIdx) { + let lhs_c = self.m.const_(lhs); + let rhs_c = self.m.const_(rhs); + match (lhs_c, rhs_c) { + (Const::Float(..), Const::Float(..)) => (), + (Const::Int(lhs_tyidx, x), Const::Int(rhs_tyidx, y)) => { + debug_assert_eq!(lhs_tyidx, rhs_tyidx); + // Constant fold comparisons of simple integers. + let x = *x; + let y = *y; + let r = match pred { + Predicate::Equal => x == y, + Predicate::NotEqual => x != y, + Predicate::UnsignedGreater => x > y, + Predicate::UnsignedGreaterEqual => x >= y, + Predicate::UnsignedLess => x < y, + Predicate::UnsignedLessEqual => x <= y, + Predicate::SignedGreater => (x as i64) > (y as i64), + Predicate::SignedGreaterEqual => (x as i64) >= (y as i64), + Predicate::SignedLess => (x as i64) < (y as i64), + Predicate::SignedLessEqual => (x as i64) <= (y as i64), + }; + + self.m.replace( + iidx, + Inst::Const(if r { + self.m.true_constidx() + } else { + self.m.false_constidx() + }), + ); + } + (Const::Ptr(x), Const::Ptr(y)) => { + // Constant fold comparisons of pointers. + let x = *x; + let y = *y; + let r = match pred { + Predicate::Equal => x == y, + Predicate::NotEqual => x != y, + Predicate::UnsignedGreater => x > y, + Predicate::UnsignedGreaterEqual => x >= y, + Predicate::UnsignedLess => x < y, + Predicate::UnsignedLessEqual => x <= y, + Predicate::SignedGreater => (x as i64) > (y as i64), + Predicate::SignedGreaterEqual => (x as i64) >= (y as i64), + Predicate::SignedLess => (x as i64) < (y as i64), + Predicate::SignedLessEqual => (x as i64) <= (y as i64), + }; + + self.m.replace( + iidx, + Inst::Const(if r { + self.m.true_constidx() + } else { + self.m.false_constidx() + }), + ); + } + _ => unreachable!(), + } + } +} + +trait SignExtend { + fn sign_extend(&self, bits: u32) -> Self; +} + +impl SignExtend for u64 { + fn sign_extend(&self, bits: u32) -> Self { + debug_assert!( + bits > 0 && bits <= Self::BITS, + "{bits} outside range 1..={}", + Self::BITS + ); + let shift = Self::BITS - bits; + (*self << shift) >> shift + } +} + +trait Truncate { + fn truncate(&self, bits: u32) -> Self; +} + +impl Truncate for u64 { + fn truncate(&self, bits: u32) -> Self { + debug_assert!( + bits > 0 && bits <= Self::BITS, + "{bits} outside range 1..={}", + Self::BITS + ); + *self & ((1 as Self).wrapping_shl(bits) - 1) + } +} + +#[derive(Clone, Debug)] +enum Value { + Unknown, + Const(ConstIdx), +} /// Create JIT IR from the (`aot_mod`, `ta_iter`) tuple. pub(super) fn opt(m: Module) -> Result { - let mut m = simple::simple(m)?; - // FIXME: When code generation supports backwards register allocation, we won't need to - // explicitly perform dead code elimination and this function can be made `#[cfg(test)]` only. - m.dead_code_elimination(); - Ok(m) + Opt::new(m).opt() +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn opt_const_guard() { + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = load_ti 0 + guard false, 0i1, [%0] + ", + |m| opt(m).unwrap(), + " + ... + entry: + ", + ); + } + + #[test] + fn opt_const_guard_indirect() { + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = eq 0i8, 0i8 + guard true, %0, [] + %1: i1 = eq 0i8, 1i8 + guard false, %1, [%0] + ", + |m| opt(m).unwrap(), + " + ... + entry: + ", + ); + } + + #[test] + fn opt_const_guard_chain() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = mul %0, 0i8 + %2: i1 = eq %1, 0i8 + guard true, %2, [%0, %1] + black_box %0 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = load_ti ... + black_box %0 + ", + ); + } + + #[test] + fn opt_mul_zero() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = load_ti 1 + %2: i8 = mul %0, 0i8 + %3: i8 = add %1, %2 + %4: i8 = mul 0i8, %0 + %5: i8 = add %1, %2 + black_box %3 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %1: i8 = load_ti ... + %3: i8 = add %1, 0i8 + %5: i8 = add %1, 0i8 + black_box %3 + black_box %5 + ", + ); + } + + #[test] + fn opt_mul_one() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = load_ti 1 + %2: i8 = mul %0, 1i8 + %3: i8 = add %1, %2 + %4: i8 = mul 1i8, %0 + %5: i8 = add %1, %2 + black_box %3 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i8 = load_ti ... + %1: i8 = load_ti ... + %3: i8 = add %1, %0 + %5: i8 = add %1, %0 + black_box %3 + black_box %5 + ", + ); + } + + #[test] + fn opt_mul_const() { + Module::assert_ir_transform_eq( + " + entry: + %0: i8 = load_ti 0 + %1: i8 = mul %0, 0i8 + %2: i8 = mul %0, 0i8 + %3: i8 = mul %1, %2 + black_box %3 + ", + |m| opt(m).unwrap(), + " + ... + entry: + black_box 0i8 + ", + ); + } + + #[test] + fn opt_mul_shl() { + Module::assert_ir_transform_eq( + " + entry: + %0: i64 = load_ti 0 + %1: i64 = mul %0, 2i64 + %2: i64 = mul %0, 4i64 + %3: i64 = mul %0, 4611686018427387904i64 + %4: i64 = mul %0, 9223372036854775807i64 + %5: i64 = mul %0, 12i64 + black_box %1 + black_box %2 + black_box %3 + black_box %4 + black_box %5 + ", + |m| opt(m).unwrap(), + " + ... + entry: + %0: i64 = load_ti ... + %1: i64 = shl %0, 1i64 + %2: i64 = shl %0, 2i64 + %3: i64 = shl %0, 62i64 + %4: i64 = mul %0, 9223372036854775807i64 + %5: i64 = mul %0, 12i64 + black_box ... + ... + ", + ); + } + + #[test] + fn opt_icmp_const() { + Module::assert_ir_transform_eq( + " + entry: + %0: i1 = eq 0i8, 0i8 + %1: i1 = eq 0i8, 1i8 + %2: i1 = ne 0i8, 0i8 + %3: i1 = ne 0i8, 1i8 + %4: i1 = ugt 0i8, 0i8 + %5: i1 = ugt 0i8, 1i8 + %6: i1 = ugt 1i8, 0i8 + %7: i1 = uge 0i8, 0i8 + %8: i1 = uge 0i8, 1i8 + %9: i1 = uge 1i8, 0i8 + %10: i1 = ult 0i8, 0i8 + %11: i1 = ult 0i8, 1i8 + %12: i1 = ult 1i8, 0i8 + %13: i1 = ule 0i8, 0i8 + %14: i1 = ule 0i8, 1i8 + %15: i1 = ule 1i8, 0i8 + %16: i1 = sgt 0i8, 0i8 + %17: i1 = sgt 0i8, -1i8 + %18: i1 = sgt -1i8, 0i8 + %19: i1 = sge 0i8, 0i8 + %20: i1 = sge 0i8, -1i8 + %21: i1 = sge -1i8, 0i8 + %22: i1 = slt 0i8, 0i8 + %23: i1 = slt 0i8, -1i8 + %24: i1 = slt -1i8, 0i8 + %25: i1 = sle 0i8, 0i8 + %26: i1 = sle 0i8, -1i8 + %27: i1 = sle -1i8, 0i8 + black_box %0 + black_box %1 + black_box %2 + black_box %3 + black_box %4 + black_box %5 + black_box %6 + black_box %7 + black_box %8 + black_box %9 + black_box %10 + black_box %11 + black_box %12 + black_box %13 + black_box %14 + black_box %15 + black_box %16 + black_box %17 + black_box %18 + black_box %19 + black_box %20 + black_box %21 + black_box %22 + black_box %23 + black_box %24 + black_box %25 + black_box %26 + black_box %27 + ", + |m| opt(m).unwrap(), + " + ... + entry: + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 0i1 + black_box 0i1 + black_box 1i1 + black_box 1i1 + black_box 0i1 + black_box 1i1 + ", + ); + } } diff --git a/ykrt/src/compile/jitc_yk/opt/simple.rs b/ykrt/src/compile/jitc_yk/opt/simple.rs deleted file mode 100644 index b50022d6d..000000000 --- a/ykrt/src/compile/jitc_yk/opt/simple.rs +++ /dev/null @@ -1,429 +0,0 @@ -//! Simple, local optimisations. -//! -//! These include strength reductions and other optimisations that can be performed with little -//! analysis. - -use crate::compile::{ - jitc_yk::jit_ir::{ - BinOp, BinOpInst, GuardInst, ICmpInst, Inst, InstIdx, Module, Operand, PackedOperand, - Predicate, - }, - CompilationError, -}; - -pub(super) fn simple(mut m: Module) -> Result { - for iidx in m.iter_all_inst_idxs() { - let (iidx, inst) = m.inst_decopy(iidx); - match inst { - Inst::BinOp(BinOpInst { - lhs, - binop: BinOp::Mul, - rhs, - }) => opt_mul(&mut m, iidx, lhs, rhs)?, - Inst::Guard(x) => opt_guard(&mut m, iidx, x)?, - Inst::ICmp(x) => opt_icmp(&mut m, iidx, x)?, - _ => (), - } - } - Ok(m) -} - -fn opt_mul( - m: &mut Module, - iidx: InstIdx, - lhs: PackedOperand, - rhs: PackedOperand, -) -> Result<(), CompilationError> { - match (lhs.unpack(m), rhs.unpack(m)) { - (Operand::Var(mut mul_inst), Operand::Const(mul_const)) - | (Operand::Const(mul_const), Operand::Var(mut mul_inst)) => { - let old_const = m.const_(mul_const); - if let Some(old_val) = old_const.int_to_u64() { - let mut new_val = old_val; - // If we've got `%2: mul %1, xi8` then see if `%1` is of the form `mul %0, yi8`: if so - // we've got a chain that's `%2: %0*x*y`. We can thus "skip" the intermediate `mul` - // when calculating the constant we're going to optimise. - if let Inst::BinOp(BinOpInst { - lhs: chain_lhs, - binop: BinOp::Mul, - rhs: chain_rhs, - }) = m.inst_no_copies(mul_inst) - { - if let (Operand::Var(chain_mul_inst), Operand::Const(chain_mul_const)) - | (Operand::Const(chain_mul_const), Operand::Var(chain_mul_inst)) = - (chain_lhs.unpack(m), chain_rhs.unpack(m)) - { - if let Some(y) = m.const_(chain_mul_const).int_to_u64() { - mul_inst = chain_mul_inst; - new_val = old_val * y; - } - } - } - - if new_val == 0 { - // Replace `x * 0` with `0`. - let cidx = m.insert_const(old_const.u64_to_int(0))?; - m.replace(iidx, Inst::Const(cidx)); - } else if new_val == 1 { - // Replace `x * 1` with `x`. - m.replace(iidx, Inst::Copy(mul_inst)); - } else if new_val.is_power_of_two() { - // Replace `x * y` with `x << ...`. - let shl = u64::from(new_val.ilog2()); - let new_const = Operand::Const(m.insert_const(old_const.u64_to_int(shl))?); - let new_inst = - BinOpInst::new(Operand::Var(mul_inst), BinOp::Shl, new_const).into(); - m.replace(iidx, new_inst); - } else if new_val != old_val { - let new_const = Operand::Const(m.insert_const(old_const.u64_to_int(new_val))?); - let new_inst = - BinOpInst::new(Operand::Var(mul_inst), BinOp::Mul, new_const).into(); - m.replace(iidx, new_inst); - } - } - } - (Operand::Const(x), Operand::Const(y)) => { - // Constant fold the unsigned multiplication of two constants. - let x = m.const_(x); - let y = m.const_(y); - // If `x_val * y_val` overflows, we're fine with the UB, as the interpreter - // author is at fault. - let new_val = x.int_to_u64().unwrap() * y.int_to_u64().unwrap(); - let new_const = m.insert_const(x.u64_to_int(new_val))?; - m.replace(iidx, Inst::Const(new_const)); - } - (Operand::Var(_), Operand::Var(_)) => (), - } - Ok(()) -} - -fn opt_icmp( - m: &mut Module, - iidx: InstIdx, - ICmpInst { lhs, pred, rhs }: ICmpInst, -) -> Result<(), CompilationError> { - if let (Operand::Const(x), Operand::Const(y)) = (lhs.unpack(m), rhs.unpack(m)) { - if let (Some(x), Some(y)) = (m.const_(x).int_to_u64(), m.const_(y).int_to_u64()) { - // Constant fold comparisons of simple integers. Note that we have to follow the - // LLVM semantics carefully. The quotes in the `match` below are from - // https://llvm.org/docs/LangRef.html#icmp-instruction. - let r = match pred { - // "eq: yields true if the operands are equal, false otherwise. No sign - // interpretation is necessary or performed." - Predicate::Equal => x == y, - // "ne: yields true if the operands are unequal, false otherwise. No sign - // interpretation is necessary or performed." - Predicate::NotEqual => x != y, - // "ugt: interprets the operands as unsigned values and yields true if op1 is - // greater than op2." - Predicate::UnsignedGreater => x > y, - // "uge: interprets the operands as unsigned values and yields true if op1 is - // greater than or equal to op2." - Predicate::UnsignedGreaterEqual => x >= y, - // "ult: interprets the operands as unsigned values and yields true if op1 is - // less than op2." - Predicate::UnsignedLess => x < y, - // "ule: interprets the operands as unsigned values and yields true if op1 is - // less than or equal to op2." - Predicate::UnsignedLessEqual => x <= y, - // "interprets the operands as signed values and yields true if op1 is greater - // than op2." - Predicate::SignedGreater => (x as i64) > (y as i64), - // "sge: interprets the operands as signed values and yields true if op1 is - // greater than or equal to op2." - Predicate::SignedGreaterEqual => (x as i64) >= (y as i64), - // "slt: interprets the operands as signed values and yields true if op1 is less - // than op2." - Predicate::SignedLess => (x as i64) < (y as i64), - // "sle: interprets the operands as signed values and yields true if op1 is less - // than or equal to op2." - Predicate::SignedLessEqual => (x as i64) <= (y as i64), - }; - - if r { - m.replace(iidx, Inst::Const(m.true_constidx())); - } else { - m.replace(iidx, Inst::Const(m.false_constidx())); - } - } - } - - Ok(()) -} - -fn opt_guard( - m: &mut Module, - iidx: InstIdx, - GuardInst { - cond, - expect: _, - gidx: _, - }: GuardInst, -) -> Result<(), CompilationError> { - if let Operand::Const(_) = cond.unpack(m) { - // A guard that references a constant is, by definition, not useful. - m.replace(iidx, Inst::Tombstone); - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn opt_mul_zero() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = load_ti 1 - %2: i8 = mul %0, 0i8 - %3: i8 = add %1, %2 - %4: i8 = mul 0i8, %0 - %5: i8 = add %1, %2 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - %1: i8 = load_ti ... - %3: i8 = add %1, 0i8 - %5: i8 = add %1, 0i8 - ", - ); - } - - #[test] - fn opt_mul_one() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = load_ti 1 - %2: i8 = mul %0, 1i8 - %3: i8 = add %1, %2 - %4: i8 = mul 1i8, %0 - %5: i8 = add %1, %2 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - %1: i8 = load_ti ... - %3: i8 = add %1, %0 - %5: i8 = add %1, %0 - ", - ); - } - - #[test] - fn opt_mul_const() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = mul %0, 0i8 - %2: i8 = mul %0, 0i8 - %3: i8 = mul %1, %2 - black_box %3 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - black_box 0i8 - ", - ); - } - - #[test] - fn opt_mul_chain() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = mul %0, 3i8 - %2: i8 = mul %1, 4i8 - %3: i8 = mul %2, 5i8 - black_box %3 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - %1: i8 = mul %0, 3i8 - %2: i8 = mul %0, 12i8 - %3: i8 = mul %0, 60i8 - black_box %3 - ", - ); - } - - #[test] - fn opt_mul_shl() { - Module::assert_ir_transform_eq( - " - entry: - %0: i64 = load_ti 0 - %1: i64 = mul %0, 2i64 - %2: i64 = mul %0, 4i64 - %3: i64 = mul %0, 4611686018427387904i64 - %4: i64 = mul %0, 9223372036854775807i64 - %5: i64 = mul %0, 12i64 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i64 = load_ti ... - %1: i64 = shl %0, 1i64 - %2: i64 = shl %0, 2i64 - %3: i64 = shl %0, 62i64 - %4: i64 = mul %0, 9223372036854775807i64 - %5: i64 = mul %0, 12i64 - ", - ); - } - - #[test] - fn opt_icmp_const() { - Module::assert_ir_transform_eq( - " - entry: - %0: i1 = eq 0i8, 0i8 - %1: i1 = eq 0i8, 1i8 - %2: i1 = ne 0i8, 0i8 - %3: i1 = ne 0i8, 1i8 - %4: i1 = ugt 0i8, 0i8 - %5: i1 = ugt 0i8, 1i8 - %6: i1 = ugt 1i8, 0i8 - %7: i1 = uge 0i8, 0i8 - %8: i1 = uge 0i8, 1i8 - %9: i1 = uge 1i8, 0i8 - %10: i1 = ult 0i8, 0i8 - %11: i1 = ult 0i8, 1i8 - %12: i1 = ult 1i8, 0i8 - %13: i1 = ule 0i8, 0i8 - %14: i1 = ule 0i8, 1i8 - %15: i1 = ule 1i8, 0i8 - %16: i1 = sgt 0i8, 0i8 - %17: i1 = sgt 0i8, -1i8 - %18: i1 = sgt -1i8, 0i8 - %19: i1 = sge 0i8, 0i8 - %20: i1 = sge 0i8, -1i8 - %21: i1 = sge -1i8, 0i8 - %22: i1 = slt 0i8, 0i8 - %23: i1 = slt 0i8, -1i8 - %24: i1 = slt -1i8, 0i8 - %25: i1 = sle 0i8, 0i8 - %26: i1 = sle 0i8, -1i8 - %27: i1 = sle -1i8, 0i8 - black_box %0 - black_box %1 - black_box %2 - black_box %3 - black_box %4 - black_box %5 - black_box %6 - black_box %7 - black_box %8 - black_box %9 - black_box %10 - black_box %11 - black_box %12 - black_box %13 - black_box %14 - black_box %15 - black_box %16 - black_box %17 - black_box %18 - black_box %19 - black_box %20 - black_box %21 - black_box %22 - black_box %23 - black_box %24 - black_box %25 - black_box %26 - black_box %27 - ", - |m| simple(m).unwrap(), - " - ... - entry: - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 0i1 - black_box 0i1 - black_box 1i1 - black_box 1i1 - black_box 0i1 - black_box 1i1 - ", - ); - } - - #[test] - fn opt_const_guard() { - Module::assert_ir_transform_eq( - " - entry: - %0: i1 = eq 0i8, 0i8 - guard true, %0, [] - %1: i1 = eq 0i8, 1i8 - guard false, %1, [%0] - ", - |m| simple(m).unwrap(), - " - ... - entry: - ", - ); - } - - #[test] - fn opt_const_guard_chain() { - Module::assert_ir_transform_eq( - " - entry: - %0: i8 = load_ti 0 - %1: i8 = mul %0, 0i8 - %2: i1 = eq %1, 0i8 - guard true, %2, [%0, %1] - black_box %0 - ", - |m| simple(m).unwrap(), - " - ... - entry: - %0: i8 = load_ti ... - black_box %0 - ", - ); - } -}