diff --git a/Cargo.lock b/Cargo.lock index 9b5e381..5aa526f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "anstream" @@ -51,6 +51,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "cfg-if" version = "1.0.0" @@ -103,6 +109,27 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "unicode-xid", +] + [[package]] name = "heck" version = "0.5.0" @@ -147,7 +174,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" name = "meraki" version = "0.1.0" dependencies = [ + "bumpalo", "clap", + "derive_more", "indoc", "libloading", "serde", @@ -254,6 +283,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index 09966f8..387a460 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bumpalo = "3.16.0" clap = { version = "4.5.9", features = ["derive"] } +derive_more = { version = "1.0.0", features = ["display"] } indoc = "2.0.5" libloading = "0.8.5" thiserror = "1.0.64" diff --git a/src/archs/amd64/amd64.rs b/src/archs/amd64/amd64.rs deleted file mode 100644 index 1ca6c80..0000000 --- a/src/archs/amd64/amd64.rs +++ /dev/null @@ -1,1194 +0,0 @@ -use crate::{ - archs::{ArchError, Architecture, Jump}, - codegen::{ - operands::{self, Base, EffectiveAddress, Immediate, Memory, Offset}, - Argument, Destination, Source, - }, - parser::{BitwiseOp, Block, CmpOp, Stmt}, - register::{ - allocator::{AllocatorError, RegisterAllocator}, - Register, - }, - scope::Scope, - symbol_table::{Symbol, SymbolTableError}, - types::{IntType, Type, UintType}, -}; -use indoc::formatdoc; -use std::collections::HashMap; - -#[derive(Eq, PartialEq, Hash)] -enum ParamClass { - Integer, -} - -impl From<&Type> for ParamClass { - fn from(value: &Type) -> Self { - match value { - Type::Int(_) - | Type::UInt(_) - | Type::Bool - | Type::Ptr(_) - | Type::Array(_) - | Type::Null => Self::Integer, - _ => unreachable!("Unsupported parameter type"), - } - } -} - -#[derive(Clone)] -pub struct Amd64 { - buf: String, - registers: RegisterAllocator, - rax: Register, - rbp: Register, - rcx: Register, - rdx: Register, - literals: Vec<(String, String)>, - label_counter: usize, -} - -impl Architecture for Amd64 { - fn new() -> Self { - let rdx = Register::new("dl", "dx", "edx", "rdx"); - - Self { - buf: String::new(), - rax: Register::new("al", "ax", "eax", "rax"), - rbp: Register::new("there's no one byte one, hmmmm", "bp", "ebp", "rbp"), - rcx: Register::new("cl", "cx", "ecx", "rcx"), - rdx, - registers: RegisterAllocator::new(vec![ - Register::new("r15b", "r15w", "r15d", "r15"), - Register::new("r14b", "r14w", "r14d", "r14"), - Register::new("r13b", "r13w", "r13d", "r13"), - Register::new("r12b", "r12w", "r12d", "r12"), - Register::new("r11b", "r11w", "r11d", "r11"), - Register::new("r10b", "r10w", "r10d", "r10"), - Register::new("r9b", "r9w", "r9d", "r9"), - Register::new("r8b", "r8w", "r8d", "r8"), - Register::new("cl", "cx", "ecx", "rcx"), - rdx, - Register::new("sil", "si", "esi", "rsi"), - Register::new("dil", "di", "edi", "rdi"), - ]), - literals: Vec::new(), - label_counter: 0, - } - } - - #[inline] - fn word_size(&self) -> usize { - 8 - } - - #[inline] - fn stack_alignment(&self) -> usize { - 16 - } - - fn size(&self, type_: &Type, scope: &Scope) -> usize { - match type_ { - Type::Ptr(_) - | Type::Null - | Type::UInt(UintType::Usize) - | Type::Int(IntType::Isize) - | Type::Fn(_, _) => self.word_size(), - Type::Custom(structure) => match scope.find_type(structure).unwrap() { - crate::type_table::Type::Struct(structure) => self.struct_size(structure, scope), - }, - Type::Array(array) => self.size(&array.type_, scope) * array.length, - type_ => type_ - .size() - .expect(&format!("Failed to get size of type {type_}")), - } - } - - fn alloc(&mut self) -> Result { - self.registers.alloc() - } - - fn free(&mut self, register: Register) -> Result<(), AllocatorError> { - self.registers.free(register) - } - - fn size_name(size: usize) -> &'static str { - match size { - 1 => "byte ptr", - 2 => "word ptr", - 4 => "dword ptr", - 8 => "qword ptr", - _ => unreachable!(), - } - } - - fn mov(&mut self, src: &Source, dest: &Destination, signed: bool) -> Result<(), ArchError> { - match (dest, src) { - (Destination::Memory(dest), Source::Memory(src)) => { - let size = src.size; - let r = self.alloc()?; - - self.lea(&r.dest(self.word_size()), &src.effective_address); - - for chunk_size in Self::size_iter(size) { - let r_tmp = self.alloc()?; - - self.mov( - &Source::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(r), - index: None, - scale: None, - displacement: Some(Offset((size - chunk_size).try_into().unwrap())), - }, - size: chunk_size, - }), - &Destination::Register(operands::Register { - register: r_tmp, - size: chunk_size, - }), - false, - )?; - self.mov( - &Source::Register(operands::Register { - register: r_tmp, - size: chunk_size, - }), - &Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: dest.effective_address.base.clone(), - index: None, - scale: None, - displacement: Some( - dest.effective_address - .displacement - .as_ref() - .unwrap_or(&Offset(0)) - - &Offset((size - chunk_size).try_into().unwrap()), - ), - }, - size: chunk_size, - }), - false, - )?; - self.free(r_tmp)?; - } - - self.free(r)?; - } - (dest, src) => { - let dest_size = dest.size(); - let src_size = src.size().unwrap_or_else(|| self.word_size()); - - if dest_size == 8 && src_size == 4 { - // On x86_64 you can move 32bit value in 32bit register, and upper 32bits of the register will be zeroed - self.mov( - src, - &Destination::Register(operands::Register { - register: self.rax, - size: 4, - }), - false, - )?; - - if signed { - self.buf.push_str("\tcdqe\n"); - } - - self.mov( - &Source::Register(operands::Register { - register: self.rax, - size: 8, - }), - dest, - false, - )?; - } else if dest_size > src_size { - if signed { - self.buf.push_str(&formatdoc!("\tmovsx {dest}, {src}\n")); - } else { - self.buf.push_str(&formatdoc!("\tmovzx {dest}, {src}\n")); - } - } else { - self.buf.push_str(&formatdoc!("\tmov {dest}, {src}\n")); - } - } - }; - - Ok(()) - } - - fn declare(&mut self, name: &str, size: usize) { - self.buf.push_str(&formatdoc!( - " - \t.comm {} {} - ", - name, - size, - )); - } - - fn negate(&mut self, dest: &Destination) { - self.buf.push_str(&formatdoc!( - " - \tneg {dest} - ", - )); - } - - fn add( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError> { - lhs.size().map(|size| assert_eq!(size, dest.size())); - rhs.size().map(|size| assert_eq!(size, dest.size())); - assert!(!(lhs == dest && rhs == dest)); - - let lhs = if let Source::Immediate(_) = lhs { - self.mov(lhs, &self.rax.dest(dest.size()), signed)?; - &self.rax.source(dest.size()) - } else { - lhs - }; - - self.buf.push_str(&formatdoc!( - " - \tadd {lhs}, {rhs} - " - )); - - if lhs != dest { - self.mov(lhs, dest, signed)?; - } - - Ok(()) - } - - fn sub( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError> { - lhs.size().map(|size| assert_eq!(size, dest.size())); - rhs.size().map(|size| assert_eq!(size, dest.size())); - assert!(!(lhs == dest && rhs == dest)); - - let lhs = if let Source::Immediate(_) = lhs { - self.mov(lhs, &self.rax.dest(dest.size()), signed)?; - &self.rax.source(dest.size()) - } else { - lhs - }; - - self.buf.push_str(&formatdoc!( - " - \tsub {lhs}, {rhs} - " - )); - - if lhs != dest { - self.mov(lhs, dest, signed)?; - } - - Ok(()) - } - - fn mul( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError> { - lhs.size().map(|size| assert_eq!(size, dest.size())); - rhs.size().map(|size| assert_eq!(size, dest.size())); - assert!(!(lhs == dest && rhs == dest)); - - self.mov( - lhs, - &self - .rax - .dest(lhs.size().unwrap_or_else(|| self.word_size())), - signed, - )?; - if rhs != dest { - self.mov(rhs, dest, signed)?; - } - if self.registers.is_used(&self.rdx) { - self.push(&self.rdx.source(self.word_size())); - } - self.buf.push_str(&formatdoc!( - " - \timul {dest} - ", - )); - if self.registers.is_used(&self.rdx) { - self.pop(&self.rdx.dest(self.word_size())); - } - self.mov(&self.rax.source(dest.size()), dest, signed)?; - - Ok(()) - } - - //NOTE: if mafs doesn't works, probably because of this xd - fn div( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError> { - lhs.size().map(|size| assert_eq!(size, dest.size())); - rhs.size().map(|size| assert_eq!(size, dest.size())); - assert!(!(lhs == dest && rhs == dest)); - - self.mov(lhs, &self.rax.dest(self.word_size()), signed)?; - self.mov(rhs, dest, signed)?; - if self.registers.is_used(&self.rdx) { - self.push(&self.rdx.source(self.word_size())); - } - self.buf.push_str(&formatdoc!( - " - \tcqo - \tidiv {dest} - ", - )); - if self.registers.is_used(&self.rdx) { - self.pop(&self.rdx.dest(self.word_size())); - } - self.mov(&self.rax.source(dest.size()), dest, signed)?; - - Ok(()) - } - - fn bitwise( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - op: BitwiseOp, - signed: bool, - ) -> Result<(), ArchError> { - lhs.size().map(|size| assert_eq!(size, dest.size())); - rhs.size().map(|size| assert_eq!(size, dest.size())); - assert!(!(lhs == dest && rhs == dest)); - - let lhs = if let Source::Immediate(_) = lhs { - self.mov(lhs, &self.rax.dest(dest.size()), signed)?; - &self.rax.source(dest.size()) - } else { - lhs - }; - - match op { - BitwiseOp::And => { - self.buf.push_str(&formatdoc!( - " - \tand {lhs}, {rhs} - " - )); - } - BitwiseOp::Or => { - self.buf.push_str(&formatdoc!( - " - \tor {lhs}, {rhs} - " - )); - } - }; - - if lhs != dest { - self.mov(lhs, dest, signed)?; - } - - Ok(()) - } - - fn bitwise_not(&mut self, dest: &Destination) { - self.buf.push_str(&formatdoc!( - " - \tnot {dest} - " - )); - } - - fn cmp(&mut self, dest: &Destination, src: &Source) { - self.buf.push_str(&formatdoc!( - " - \tcmp {dest}, {src} - ", - )); - } - - fn setcc(&mut self, dest: &Destination, condition: CmpOp) { - let ins = match condition { - CmpOp::LessThan => "setl", - CmpOp::LessEqual => "setle", - CmpOp::GreaterThan => "setg", - CmpOp::GreaterEqual => "setge", - CmpOp::Equal => "sete", - CmpOp::NotEqual => "setne", - }; - - self.buf.push_str(&formatdoc!( - " - \t{ins} {dest} - ", - )); - } - - fn fn_preamble( - &mut self, - name: &str, - params: &[Type], - stackframe: usize, - scope: &Scope, - ) -> Result<(), ArchError> { - self.buf.push_str(&formatdoc!( - " - .global {name} - {name}: - ", - )); - - if stackframe > 0 { - self.buf.push_str(&formatdoc!( - " - \tpush rbp - \tmov rbp, rsp - \tsub rsp, {stackframe} - " - )); - } - - let mut occurences: HashMap = HashMap::new(); - let mut offset = Offset(0); - - for type_ in params { - let n = occurences.entry(ParamClass::from(type_)).or_insert(0); - *n += 1; - - match ParamClass::from(type_) { - ParamClass::Integer => { - if *n <= 6 { - let n = *n; - let size = self.size(type_, scope); - offset = &offset - (size as isize); - - self.mov( - &Source::Register(operands::Register { - register: self.registers.get(self.registers.len() - n).unwrap(), - size, - }), - &Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(self.rbp), - index: None, - scale: None, - displacement: Some(offset.clone()), - }, - size, - }), - type_.signed(), - )?; - } - } - } - } - - Ok(()) - } - - fn fn_postamble(&mut self, name: &str, stackframe: usize) { - assert_eq!( - self.registers.used().len(), - 0, - "Function '{name}' didn't free all registers. {:?} are still allocated", - self.registers.used() - ); - - self.buf.push_str(&formatdoc!( - " - {name}_ret: - " - )); - - if stackframe > 0 { - self.buf.push_str(&formatdoc!( - " - \tleave - " - )); - } - - self.buf.push_str(&formatdoc!( - " - \tret - " - )) - } - - fn ret(&mut self, src: &Source, signed: bool) -> Result<(), ArchError> { - self.mov(&src, &self.rax.dest(self.word_size()), signed) - } - - fn jcc(&mut self, label: &str, kind: Jump) { - let ins = match kind { - Jump::Unconditional => "jmp", - Jump::Equal => "je", - Jump::NotEqual => "jne", - Jump::GreaterThan => "jg", - Jump::GreaterEqual => "jge", - Jump::LessThan => "jl", - Jump::LessEqual => "jle", - }; - - self.buf.push_str(&formatdoc!( - " - \t{ins} {label} - " - )); - } - - fn call( - &mut self, - src: &Source, - dest: Option<&Destination>, - signed: bool, - size: usize, - ) -> Result<(), ArchError> { - self.buf.push_str(&formatdoc!( - " - \tcall {src} - ", - )); - - if let Some(dest) = dest { - self.mov( - &Source::Register(operands::Register { - register: self.rax, - size, - }), - dest, - signed, - )?; - } - - Ok(()) - } - - fn push_arg(&mut self, src: Source, type_: &Type, preceding: &[Type]) -> Argument { - let mut occurences: HashMap = HashMap::new(); - let class = ParamClass::from(type_); - - preceding - .iter() - .for_each(|param| *occurences.entry(ParamClass::from(param)).or_insert(0) += 1); - - match class { - ParamClass::Integer => match occurences.get(&class).unwrap_or(&0) + 1 { - n if n <= 6 => { - let r = self.registers.alloc_nth(self.registers.len() - n).unwrap(); - self.mov(&src, &r.dest(self.word_size()), type_.signed()) - .unwrap(); - - Argument::Register(r) - } - _ => { - self.push(&src); - - Argument::Stack(self.word_size()) - } - }, - } - } - - fn lea(&mut self, dest: &Destination, address: &EffectiveAddress) { - self.buf.push_str(&formatdoc!( - " - \tlea {dest}, {address} - " - )); - } - - fn populate_offsets( - &mut self, - block: &mut Block, - scope: &Scope, - mut offset: isize, - ) -> Result { - let mut occurences: HashMap = HashMap::new(); - - for param in block.scope.symbol_table.iter_mut().filter_map(|symbol| { - if let Symbol::Param(param) = symbol { - Some(param) - } else { - None - } - }) { - let n = occurences - .entry(ParamClass::from(¶m.type_)) - .or_insert(0); - *n += 1; - - if *n <= 6 { - offset -= self.size(¶m.type_, scope) as isize; - param.offset = Offset(offset); - } else { - // When call instruction is called it pushes return address on da stack - param.offset = Offset(((*n - 6) * self.word_size() + 8) as isize); - } - } - - for stmt in &mut block.statements { - match stmt { - Stmt::VarDecl(stmt2) => { - offset -= self.size(&stmt2.type_, scope) as isize; - - match block.scope.symbol_table.find_mut(&stmt2.name).unwrap() { - Symbol::Local(local) => { - local.offset = Offset(offset); - } - _ => unreachable!(), - }; - } - Stmt::If(stmt) => { - offset = self.populate_offsets(&mut stmt.consequence, scope, offset)?; - - if let Some(alternative) = &mut stmt.alternative { - offset = self.populate_offsets(alternative, scope, offset)?; - } - } - Stmt::While(stmt) => { - offset = self.populate_offsets(&mut stmt.block, scope, offset)?; - } - Stmt::For(stmt) => { - if let Some(Stmt::VarDecl(stmt2)) = stmt.initializer.as_deref() { - offset -= self.size(&stmt2.type_, scope) as isize; - - match stmt.block.scope.symbol_table.find_mut(&stmt2.name).unwrap() { - Symbol::Local(local) => { - local.offset = Offset(offset); - } - _ => unreachable!(), - }; - } - - offset = self.populate_offsets(&mut stmt.block, scope, offset)?; - } - Stmt::Return(_) | Stmt::Expr(_) | Stmt::Continue | Stmt::Break => (), - Stmt::Function(_) => unreachable!(), - } - } - - Ok(offset) - } - - fn shrink_stack(&mut self, size: usize) { - self.buf.push_str(&formatdoc!( - " - \tsub rsp, {size} - " - )); - } - - fn generate_label(&mut self) -> String { - let label = format!(".L{}", self.label_counter); - self.label_counter += 1; - - label - } - - fn write_label(&mut self, label: &str) { - self.buf.push_str(&format!("{label}:\n")); - } - - fn define_literal(&mut self, literal: String) -> String { - let label = self.generate_label(); - - self.literals.push((label.clone(), literal)); - - label - } - - fn array_offset( - &mut self, - base: &Source, - index: &Source, - size: usize, - dest: &Destination, - ) -> Result<(), ArchError> { - let r = self.alloc()?; - - self.mul( - index, - &Source::Immediate(Immediate::UInt(size as u64)), - &r.dest(dest.size()), - false, - )?; - self.add(base, &r.source(dest.size()), dest, false)?; - - self.free(r)?; - - Ok(()) - } - - fn shl(&mut self, dest: &Destination, src: &Source) -> Result<(), ArchError> { - self.mov( - src, - &Destination::Register(operands::Register { - register: self.rcx, - size: self.word_size(), - }), - false, - )?; - self.buf.push_str(&formatdoc!( - " - \tshl {dest}, {} - ", - self.rcx.from_size(1) - )); - - Ok(()) - } - - fn shr(&mut self, dest: &Destination, src: &Source) -> Result<(), ArchError> { - self.mov( - src, - &Destination::Register(operands::Register { - register: self.rcx, - size: self.word_size(), - }), - false, - )?; - self.buf.push_str(&formatdoc!( - " - \tshr {dest}, {} - ", - self.rcx.from_size(1) - )); - - Ok(()) - } - - fn finish(&mut self) -> Vec { - self.literals.iter().for_each(|(label, value)| { - self.buf.insert_str( - 0, - &formatdoc!( - " - {label}: - .string \"{value}\" - " - ), - ); - }); - self.buf.insert_str(0, ".section .text\n"); - self.buf.as_bytes().to_vec() - } - - fn push(&mut self, src: &Source) { - self.buf.push_str(&formatdoc!( - " - \tpush {src} - " - )); - } - - fn symbol_source(&self, name: &str, scope: &Scope) -> Result { - Ok( - match scope.find_symbol(name).ok_or(ArchError::SymbolTable( - SymbolTableError::NotFound(name.to_owned()), - ))? { - Symbol::Local(symbol) => Source::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(self.rbp), - index: None, - scale: None, - displacement: Some(symbol.offset.clone()), - }, - size: self.size(&symbol.type_, scope), - }), - Symbol::Global(symbol) => Source::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Label(name.to_owned()), - index: None, - scale: None, - displacement: None, - }, - size: self.size(&symbol.type_, scope), - }), - Symbol::Param(symbol) => Source::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(self.rbp), - index: None, - scale: None, - displacement: Some(symbol.offset.clone()), - }, - size: self.size(&symbol.type_, scope), - }), - Symbol::Function(_) => Source::Immediate(Immediate::Label(name.to_owned())), - }, - ) - } - - fn pop(&mut self, dest: &Destination) { - self.buf.push_str(&formatdoc!( - " - \tpop {dest} - " - )); - } -} - -impl Amd64 { - /// Transform variable size into iterator or sizes which can be used for `mov` - /// For example if you use value 11 you will have iterator wil values [8, 2, 1] - fn size_iter(mut size: usize) -> impl Iterator { - let mut sizes = Vec::new(); - while size > 0 { - let chunk_size = match size { - 8.. => 8, - 4..=7 => 4, - 2..=3 => 2, - 1 => 1, - 0 => unreachable!(), - }; - - sizes.push(chunk_size); - size -= chunk_size; - } - - sizes.into_iter() - } -} - -impl std::fmt::Display for operands::Register { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.register.from_size(self.size)) - } -} - -impl std::fmt::Display for operands::Immediate { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Int(int) => write!(f, "{int}"), - Self::UInt(uint) => write!(f, "{uint}"), - Self::Label(label) => write!(f, "OFFSET {label}"), - } - } -} - -impl std::fmt::Display for operands::Memory { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{} {}", - Amd64::size_name(self.size), - self.effective_address - ) - } -} - -impl std::fmt::Display for operands::Base { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Register(register) => write!(f, "{}", register.qword()), - Self::Label(label) => write!(f, "{label}"), - } - } -} - -impl std::fmt::Display for operands::EffectiveAddress { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut str = format!("[{}", self.base); - - if let Some(index) = &self.index { - str.push_str(&format!(" + {}", index.qword())); - } - - if let Some(scale) = self.scale { - str.push_str(&format!("* {scale}")); - } - - if let Some(displacement) = &self.displacement { - str.push_str(&format!("{displacement}")); - } - - write!(f, "{}]", str) - } -} - -impl std::fmt::Display for operands::Source { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Memory(memory) => write!(f, "{memory}"), - Self::Register(register) => write!(f, "{register}"), - Self::Immediate(immediate) => write!(f, "{immediate}"), - } - } -} - -impl std::fmt::Display for operands::Destination { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Memory(effective_address) => write!(f, "{effective_address}"), - Self::Register(register) => write!(f, "{register}"), - } - } -} - -#[cfg(test)] -mod test { - pub const RBP: Register = Register::new("there's no one byte one, hmmmm", "bp", "ebp", "rbp"); - use super::Amd64; - use crate::{ - archs::Architecture, - codegen::operands::{ - self, Base, Destination, EffectiveAddress, Immediate, Memory, Offset, Source, - }, - register::Register, - }; - - #[test] - fn mov_literal() { - let r = Register::new("r15b", "r15w", "r15d", "r15"); - let tests = vec![ - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Label("foo".to_string()), - index: None, - scale: None, - displacement: Some(Offset(-5)), - }, - size: 4, - }), - Immediate::UInt(15_000), - ), - "\tmov dword ptr [foo - 5], 15000\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Label("foo".to_string()), - index: None, - scale: None, - displacement: None, - }, - size: 8, - }), - Immediate::Int(-5), - ), - "\tmov qword ptr [foo], -5\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(RBP), - index: None, - scale: None, - displacement: Some(Offset(-1)), - }, - size: 4, - }), - Immediate::UInt(5), - ), - "\tmov dword ptr [rbp - 1], 5\n", - ), - ( - ( - Destination::Register(operands::Register { - register: r, - size: 8, - }), - Immediate::UInt(5), - ), - "\tmov r15, 5\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(r), - index: None, - scale: None, - displacement: Some(Offset(-15)), - }, - size: 8, - }), - Immediate::UInt(5), - ), - "\tmov qword ptr [r15 - 15], 5\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(r), - index: None, - scale: None, - displacement: Some(Offset(8)), - }, - size: 2, - }), - Immediate::Int(-7), - ), - "\tmov word ptr [r15 + 8], -7\n", - ), - ]; - - for ((dest, immidiate), expected) in tests { - let mut arch = Amd64::new(); - arch.mov(&Source::Immediate(immidiate), &dest, false) - .unwrap(); - - assert_eq!(arch.buf, expected); - } - } - - #[test] - fn mov_register() { - let r = Register::new("r15b", "r15w", "r15d", "r15"); - let r2 = Register::new("r14b", "r14w", "r14d", "r14"); - let tests = vec![ - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Label("foo".to_string()), - index: None, - scale: None, - displacement: Some(Offset(-5)), - }, - size: 4, - }), - operands::Register { - register: r, - size: 4, - }, - false, - ), - "\tmov dword ptr [foo - 5], r15d\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Label("foo".to_string()), - index: None, - scale: None, - displacement: None, - }, - size: 8, - }), - operands::Register { - register: r, - size: 4, - }, - false, - ), - "\tmov eax, r15d\n\tmov qword ptr [foo], rax\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Label("foo".to_string()), - index: None, - scale: None, - displacement: None, - }, - size: 8, - }), - operands::Register { - register: r, - size: 4, - }, - true, - ), - "\tmov eax, r15d\n\tcdqe\n\tmov qword ptr [foo], rax\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(RBP), - index: None, - scale: None, - displacement: Some(Offset(-10)), - }, - size: 1, - }), - operands::Register { - register: r, - size: 1, - }, - true, - ), - "\tmov byte ptr [rbp - 10], r15b\n", - ), - ( - ( - Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(r2), - index: None, - scale: None, - displacement: Some(Offset(10)), - }, - size: 2, - }), - operands::Register { - register: r, - size: 1, - }, - true, - ), - "\tmovsx word ptr [r14 + 10], r15b\n", - ), - ( - ( - Destination::Register(operands::Register { - register: r2, - size: 8, - }), - operands::Register { - register: r, - size: 4, - }, - false, - ), - "\tmov eax, r15d\n\tmov r14, rax\n", - ), - ( - ( - Destination::Register(operands::Register { - size: 1, - register: r2, - }), - operands::Register { - register: r, - size: 1, - }, - false, - ), - "\tmov r14b, r15b\n", - ), - ]; - - for ((dest, r, signed), expected) in tests { - let mut arch = Amd64::new(); - arch.mov(&Source::Register(r), &dest, signed).unwrap(); - - assert_eq!(arch.buf, expected); - } - } -} diff --git a/src/archs/amd64/mod.rs b/src/archs/amd64/mod.rs deleted file mode 100644 index 55cf2a6..0000000 --- a/src/archs/amd64/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod amd64; - -pub use amd64::Amd64; diff --git a/src/archs/arch.rs b/src/archs/arch.rs deleted file mode 100644 index 353d4dc..0000000 --- a/src/archs/arch.rs +++ /dev/null @@ -1,160 +0,0 @@ -use super::ArchError; -use crate::{ - codegen::{Argument, Destination, EffectiveAddress, Source}, - parser::{BitwiseOp, Block, CmpOp}, - register::{allocator::AllocatorError, Register}, - scope::Scope, - type_table as tt, - types::Type, -}; - -pub enum Jump { - Unconditional, - Equal, - NotEqual, - GreaterThan, - GreaterEqual, - LessThan, - LessEqual, -} - -pub trait ArchitectureClone { - fn clone_box(&self) -> Arch; -} - -pub trait Architecture: ArchitectureClone { - fn new() -> Self - where - Self: Sized; - fn word_size(&self) -> usize; - fn stack_alignment(&self) -> usize; - fn size(&self, type_: &Type, scope: &Scope) -> usize; - fn struct_size(&self, type_struct: &tt::TypeStruct, scope: &Scope) -> usize { - let mut offset: usize = 0; - let mut largest = 0; - - for type_ in type_struct.types() { - let size = self.size(type_, scope); - - offset = offset.next_multiple_of(size); - offset += size; - - if size > largest { - largest = size; - } - } - - // Align to the largest element in the struct - if largest > 0 { - offset.next_multiple_of(largest) - } else { - 0 - } - } - fn alloc(&mut self) -> Result; - fn free(&mut self, register: Register) -> Result<(), AllocatorError>; - fn size_name(size: usize) -> &'static str - where - Self: Sized; - fn declare(&mut self, name: &str, size: usize); - fn mov(&mut self, src: &Source, dest: &Destination, signed: bool) -> Result<(), ArchError>; - fn negate(&mut self, dest: &Destination); - fn add( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError>; - fn sub( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError>; - fn mul( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError>; - fn div( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - signed: bool, - ) -> Result<(), ArchError>; - fn bitwise( - &mut self, - lhs: &Source, - rhs: &Source, - dest: &Destination, - op: BitwiseOp, - signed: bool, - ) -> Result<(), ArchError>; - fn bitwise_not(&mut self, dest: &Destination); - fn cmp(&mut self, dest: &Destination, src: &Source); - fn setcc(&mut self, dest: &Destination, condition: CmpOp); - fn fn_preamble( - &mut self, - name: &str, - params: &[Type], - stackframe: usize, - scope: &Scope, - ) -> Result<(), ArchError>; - fn fn_postamble(&mut self, name: &str, stackframe: usize); - fn ret(&mut self, src: &Source, signed: bool) -> Result<(), ArchError>; - fn jcc(&mut self, label: &str, kind: Jump); - fn call( - &mut self, - src: &Source, - dest: Option<&Destination>, - signed: bool, - size: usize, - ) -> Result<(), ArchError>; - fn push_arg(&mut self, src: Source, type_: &Type, preceding: &[Type]) -> Argument; - fn populate_offsets( - &mut self, - block: &mut Block, - scope: &Scope, - start_offset: isize, - ) -> Result; - fn lea(&mut self, dest: &Destination, address: &EffectiveAddress); - fn shrink_stack(&mut self, size: usize); - fn generate_label(&mut self) -> String; - fn write_label(&mut self, label: &str); - fn define_literal(&mut self, literal: String) -> String; - fn array_offset( - &mut self, - base: &Source, - index: &Source, - size: usize, - dest: &Destination, - ) -> Result<(), ArchError>; - fn shl(&mut self, dest: &Destination, src: &Source) -> Result<(), ArchError>; - fn shr(&mut self, dest: &Destination, src: &Source) -> Result<(), ArchError>; - fn push(&mut self, src: &Source); - fn pop(&mut self, dest: &Destination); - fn symbol_source(&self, name: &str, scope: &Scope) -> Result; - fn finish(&mut self) -> Vec; -} - -pub type Arch = Box; - -impl ArchitectureClone for T -where - T: 'static + Architecture + Clone, -{ - fn clone_box(&self) -> Arch { - Box::new(self.clone()) - } -} - -impl Clone for Arch { - fn clone(&self) -> Self { - self.clone_box() - } -} diff --git a/src/archs/error.rs b/src/archs/error.rs deleted file mode 100644 index 8018b62..0000000 --- a/src/archs/error.rs +++ /dev/null @@ -1,17 +0,0 @@ -use crate::{ - parser::ExprError, register::allocator::AllocatorError, symbol_table::SymbolTableError, - types::TypeError, -}; -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum ArchError { - #[error(transparent)] - Type(#[from] TypeError), - #[error(transparent)] - Allocator(#[from] AllocatorError), - #[error(transparent)] - SymbolTable(#[from] SymbolTableError), - #[error(transparent)] - Expr(#[from] ExprError), -} diff --git a/src/archs/mod.rs b/src/archs/mod.rs deleted file mode 100644 index 717547a..0000000 --- a/src/archs/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod arch; -mod error; - -pub use arch::{Arch, Architecture, Jump}; -pub use error::ArchError; - -mod amd64; -pub use amd64::*; diff --git a/src/register/allocator/allocator.rs b/src/codegen/amd64_asm/allocator.rs similarity index 66% rename from src/register/allocator/allocator.rs rename to src/codegen/amd64_asm/allocator.rs index 5ce0503..85b98bc 100644 --- a/src/register/allocator/allocator.rs +++ b/src/codegen/amd64_asm/allocator.rs @@ -1,5 +1,15 @@ -use super::AllocatorError; -use crate::register::Register; +use super::{register::Register, OperandSize}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Register was double freed")] + DoubleFree, + #[error("Ran out of registers, whoops!")] + RanOutOfRegisters, + #[error("Register {0} is already in use")] + AlreadyInUse(Register), +} #[derive(Debug, Clone)] pub struct RegisterAllocator { @@ -15,29 +25,29 @@ impl RegisterAllocator { }; } - pub fn alloc(&mut self) -> Result { + pub fn alloc(&mut self, size: OperandSize) -> Result { for (i, reg) in self.registers.iter().enumerate() { if !self.used.contains(&i.try_into().unwrap()) { self.used.push(i.try_into().unwrap()); - return Ok(reg.clone()); + return Ok(reg.resize(size)); } } - Err(AllocatorError::RanOutOfRegisters) + Err(Error::RanOutOfRegisters) } - pub fn alloc_nth(&mut self, n: usize) -> Result { + pub fn alloc_nth(&mut self, n: usize) -> Result { if self.used.contains(&n) { - Err(AllocatorError::AlreadyInUse(self.registers[n])) + Err(Error::AlreadyInUse(self.registers[n].clone())) } else { self.used.push(n); - Ok(self.registers[n]) + Ok(self.registers[n].clone()) } } - pub fn free(&mut self, r: Register) -> Result<(), AllocatorError> { + pub fn free(&mut self, r: Register) -> Result<(), Error> { for (i, register) in self.registers.iter().enumerate() { if &r == register { if let Some(i) = self.used.iter().position(|el| el == &i) { @@ -45,7 +55,7 @@ impl RegisterAllocator { break; } else { - return Err(AllocatorError::DoubleFree); + return Err(Error::DoubleFree); } } } diff --git a/src/codegen/amd64_asm/mod.rs b/src/codegen/amd64_asm/mod.rs new file mode 100644 index 0000000..6d07bea --- /dev/null +++ b/src/codegen/amd64_asm/mod.rs @@ -0,0 +1,863 @@ +mod allocator; +mod operand; +mod register; + +use super::Codegen; +use crate::{ + ir::{Block, Expr, ExprKind, ExprLit, Id, Item, ItemFn, Stmt, Ty, Variable}, + parser::{BinOp, BitwiseOp, CmpOp, OpParseError, UnOp}, + Context, +}; +use allocator::RegisterAllocator; +use derive_more::derive::Display; +use indoc::formatdoc; +use operand::{ + Base, Destination, EffectiveAddress, Immediate, ImmediateStrLitError, Memory, Source, +}; +use register::Register; +use std::collections::HashMap; +use thiserror::Error; + +struct LabelGenerator(usize); + +impl LabelGenerator { + pub fn new() -> Self { + Self(0) + } + + pub fn generate(&mut self) -> String { + let str = format!(".L{}", self.0); + + self.0 += 1; + + str + } +} + +#[derive(Error, Debug)] +#[error("{0} is not a valid operand size")] +pub struct InvalidOperandSizeError(usize); + +#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Display)] +pub enum OperandSize { + #[display("byte ptr")] + Byte, + #[display("word ptr")] + Word, + #[display("dword ptr")] + Dword, + #[display("qword ptr")] + Qword, +} + +impl TryFrom for OperandSize { + type Error = InvalidOperandSizeError; + + fn try_from(value: usize) -> Result { + Ok(match value { + 1 => Self::Byte, + 2 => Self::Word, + 4 => Self::Dword, + 8 => Self::Qword, + _ => return Err(InvalidOperandSizeError(value)), + }) + } +} + +#[derive(Debug, Error)] +pub enum Amd64AsmError { + #[error(transparent)] + InvalidOperandSize(#[from] InvalidOperandSizeError), + #[error(transparent)] + RegisterAllocator(#[from] allocator::Error), + #[error(transparent)] + OpParse(#[from] OpParseError), + #[error(transparent)] + ImmediateStrLit(#[from] ImmediateStrLitError), +} + +impl std::fmt::Display for CmpOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::LessThan => "setl", + Self::LessEqual => "setle", + Self::GreaterThan => "setg", + Self::GreaterEqual => "setge", + Self::Equal => "sete", + Self::NotEqual => "setne", + } + ) + } +} + +impl std::fmt::Display for BitwiseOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::And => "and", + Self::Or => "or", + } + ) + } +} + +#[derive(Display)] +pub enum Jump { + #[display("jmp")] + Unconditional, + #[display("je")] + Equal, + #[display("jne")] + NotEqual, + #[display("jg")] + GreaterThan, + #[display("jge")] + GreaterEqual, + #[display("jl")] + LessThan, + #[display("jle")] + LessEqual, +} + +pub struct Amd64Asm<'a, 'ir> { + ctx: &'a Context<'ir>, + allocator: RegisterAllocator, + label_gen: LabelGenerator, + bss: String, + data: String, + text: String, + stack_offset: isize, + variables: HashMap, +} + +impl<'a, 'ir> Codegen<'a, 'ir> for Amd64Asm<'a, 'ir> { + fn new(ctx: &'a Context<'ir>) -> Self { + Self { + ctx, + label_gen: LabelGenerator::new(), + allocator: RegisterAllocator::new(vec![ + Register::R15, + Register::R14, + Register::R13, + Register::R12, + Register::R11, + Register::R10, + Register::R9, + Register::R8, + Register::Rcx, + Register::Rdx, + Register::Rsi, + Register::Rdi, + ]), + bss: String::new(), + data: String::new(), + text: String::new(), + stack_offset: 0, + variables: HashMap::new(), + } + } + + fn compile(&mut self) -> Result, Box> { + for item in self.ctx.ir.iter_items() { + self.item(&item)?; + } + + let mut result = String::new(); + + if !self.bss.is_empty() { + result.push_str(".section .bss\n"); + result.push_str(&self.bss); + } + if !self.data.is_empty() { + result.push_str(".section .data\n"); + result.push_str(&self.data); + } + if !self.text.is_empty() { + result.push_str(".section .text\n"); + result.push_str(&self.text); + } + + Ok(result.into_bytes()) + } +} + +impl<'a, 'ir> Amd64Asm<'a, 'ir> { + const BITNESS: usize = 64; + + fn expr_dest(&mut self, expr: &Expr) -> Result { + Ok(match expr.kind { + ExprKind::Ident(id) => self.variables[&id].clone(), + ExprKind::Unary(op, expr) if op == UnOp::Deref => { + let ty = self.ctx.resolve_ty(expr.ty); + let ty_size = self.ty_size(ty).try_into()?; + let r = self.allocator.alloc(ty_size)?; + + self.expr(expr, Some(&r.into()))?; + + Destination::Memory(Memory { + effective_address: EffectiveAddress { + base: Base::Register(r), + index: None, + scale: None, + displacement: None, + }, + size: ty_size, + }) + } + expr => unreachable!("{expr:?} is not a valid lvalue expression"), + }) + } + + fn item(&mut self, item: &Item) -> Result<(), Amd64AsmError> { + match item { + Item::Fn(item) => self.function(item), + Item::Global(item) => { + self.global(item)?; + + Ok(()) + } + } + } + + fn function(&mut self, item: &ItemFn) -> Result<(), Amd64AsmError> { + let name = item.name; + + self.text.push_str(&formatdoc!( + " + .global {name} + {name}: + " + )); + + let stack_frame = self.stack_frame_size(&item.block); + + if stack_frame > 0 { + let rbp = &Register::Rbp.into(); + self.push(rbp); + self.mov(&Register::Rsp.into(), &Register::Rbp.into(), false)?; + self.sub( + &Register::Rsp.into(), + &(stack_frame as u64).into(), + &Register::Rsp.into(), + false, + )?; + } + + item.block + .0 + .iter() + .map(|stmt| self.stmt(stmt)) + .collect::>()?; + + let ret_label = &self.label_gen.generate(); + self.write_label(&ret_label); + + if stack_frame > 0 { + self.text.push_str(&format!("\tleave\n")); + } + + self.text.push_str(&format!("\tret\n")); + self.stack_offset = 0; + self.variables.clear(); + + Ok(()) + } + + fn stmt(&mut self, stmt: &Stmt) -> Result<(), Amd64AsmError> { + match stmt { + Stmt::Local(stmt) => self.local(stmt), + Stmt::Return(expr) => self.ret(expr.as_ref()), + Stmt::Item(item) => self.item(item), + Stmt::Expr(expr) => self.expr(expr, None), + } + } + + fn expr(&mut self, expr: &Expr, dest: Option<&Destination>) -> Result<(), Amd64AsmError> { + Ok(match expr.kind { + ExprKind::Binary(op, lhs, rhs) => { + macro_rules! eval_expr { + // This shit looks scary + ($(($out: ident, $expr: ident)),+) => { + $( + let $out = self.allocator.alloc(self.ty_size(self.ctx.resolve_ty($expr.ty)).try_into()?)?; + self.expr($expr, Some(&$out.into()))?; + )+ + }; + } + + let signed = self.ctx.resolve_ty(expr.ty).signed(); + + match op { + BinOp::Assign => { + let expr_dest = self.expr_dest(lhs)?; + + self.expr(rhs, Some(&expr_dest))?; + + if let Some(dest) = dest { + self.mov(&expr_dest.into(), dest, signed)?; + } + } + BinOp::Add => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.add(&r_lhs.into(), &r_rhs.into(), dest, signed)?; + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + BinOp::Sub => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.sub(&r_lhs.into(), &r_rhs.into(), dest, signed)?; + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + BinOp::Mul => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.mul(&r_lhs.into(), &r_rhs.into(), dest, signed)?; + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + BinOp::Div => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.div(&r_lhs.into(), &r_rhs.into(), dest, signed)?; + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + BinOp::Equal + | BinOp::NotEqual + | BinOp::LessThan + | BinOp::LessEqual + | BinOp::GreaterThan + | BinOp::GreaterEqual => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.cmp(&r_lhs.into(), &r_rhs.into()); + self.setcc(dest, CmpOp::try_from(&op)?); + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + BinOp::LogicalOr => { + if let Some(dest) = dest { + self.logical_or(expr, dest, None)? + } + } + BinOp::LogicalAnd => { + if let Some(dest) = dest { + self.logical_and(expr, dest, None)? + } + } + BinOp::BitwiseOr | BinOp::BitwiseAnd => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.bitwise( + &r_lhs.into(), + &r_rhs.into(), + dest, + BitwiseOp::try_from(&op)?, + self.ctx.resolve_ty(expr.ty).signed(), + )?; + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + BinOp::Shl => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.mov(&r_lhs.into(), dest, lhs.ty.signed())?; + self.shl(dest, &r_rhs.into())?; + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + BinOp::Shr => { + if let Some(dest) = dest { + eval_expr!((r_lhs, lhs), (r_rhs, rhs)); + + self.mov(&r_lhs.into(), dest, lhs.ty.signed())?; + self.shr(dest, &r_rhs.into())?; + + self.allocator.free(r_lhs)?; + self.allocator.free(r_rhs)?; + } + } + }; + } + ExprKind::Unary(op, inner_expr) => { + if let Some(dest) = dest { + let signed = self.ctx.resolve_ty(inner_expr.ty).signed(); + + match op { + UnOp::LogicalNot => { + self.expr(inner_expr, Some(dest))?; + self.cmp(dest, &Source::Immediate(Immediate::UInt(0))); + self.setcc(dest, CmpOp::Equal); + } + UnOp::Negative => { + self.expr(inner_expr, Some(dest))?; + self.negate(dest); + } + UnOp::Address => { + let expr_dest = self.expr_dest(inner_expr)?; + let r = self.allocator.alloc(OperandSize::Qword)?; + + self.lea(&r.into(), &expr_dest.into()); + self.mov(&r.into(), dest, signed)?; + + self.allocator.free(r)?; + } + UnOp::Deref => { + let expr_dest = self.expr_dest(expr)?; + + self.mov(&expr_dest.into(), dest, signed)?; + } + UnOp::BitwiseNot => { + self.expr(inner_expr, Some(dest))?; + self.bitwise_not(dest); + } + } + } + } + ExprKind::Lit(lit) => { + if let Some(dest) = dest { + if let ExprLit::String(str) = lit { + let label = self.define_str_literal(str); + + self.mov(&Source::Immediate(label.into()), dest, false)?; + } else { + let signed = if let ExprLit::Int(_) = lit { + true + } else { + false + }; + + self.mov(&Source::Immediate(lit.try_into()?), dest, signed)?; + } + } + } + ExprKind::Ident(id) => { + if let Some(dest) = dest { + let expr = self.variables.get(&id).unwrap(); + + self.mov(&expr.clone().into(), dest, false)?; + } + } + }) + } + + fn global(&mut self, item: &Variable) -> Result<(), Amd64AsmError> { + Ok(()) + } + + fn local(&mut self, stmt: &Variable) -> Result<(), Amd64AsmError> { + let size = self.ty_size(self.ctx.resolve_ty(stmt.ty)); + + self.stack_offset -= size as isize; + let dest = Destination::Memory(Memory { + effective_address: Register::Rbp.into_effective_addr(self.stack_offset), + size: size.try_into()?, + }); + + self.variables.insert( + stmt.id, + Destination::Memory(Memory { + effective_address: Register::Rbp.into_effective_addr(self.stack_offset), + size: size.try_into()?, + }), + ); + + if let Some(expr) = stmt.initializer { + self.expr(&expr, Some(&dest))?; + } + + Ok(()) + } + + fn define_str_literal(&mut self, literal: &str) -> String { + let label = self.label_gen.generate(); + + self.data.push_str(&formatdoc!( + " + {}: + .string \"{}\" + ", + label, + literal + )); + + label + } + + fn ret(&mut self, expr: Option<&Expr>) -> Result<(), Amd64AsmError> { + if let Some(expr) = expr { + let ty = self.ctx.resolve_ty(expr.ty); + let r = self.allocator.alloc(self.ty_size(ty).try_into()?)?; + let dest = r.clone().into(); + + self.expr(expr, Some(&dest))?; + self.mov(&dest.into(), &Register::Rax.into(), false)?; + self.allocator.free(r)?; + } + + Ok(()) + } + + fn mov(&mut self, src: &Source, dest: &Destination, signed: bool) -> Result<(), Amd64AsmError> { + match (dest, src) { + (dest @ Destination::Memory(_), src @ Source::Memory(_)) => { + let r = self.allocator.alloc(dest.size()).unwrap(); + + self.mov(src, &r.clone().into(), signed)?; + self.mov(&r.clone().into(), dest, signed)?; + + self.allocator.free(r)?; + } + (dest, src) => { + let dest_size = dest.size(); + let src_size = src.size().unwrap_or_else(|| OperandSize::Qword); + + if dest_size == OperandSize::Qword && src_size == OperandSize::Dword { + // On x86_64 you can move 32bit value in 32bit register, and upper 32bits of the register will be zeroed + self.mov(src, &Register::Eax.into(), false)?; + + if signed { + self.text.push_str("\tcdqe\n"); + } + + self.mov(&Register::Rax.into(), dest, false)?; + } else if dest_size > src_size { + if signed { + self.text.push_str(&format!("\tmovsx {dest}, {src}\n")); + } else { + self.text.push_str(&format!("\tmovzx {dest}, {src}\n")); + } + } else { + self.text.push_str(&format!("\tmov {dest}, {src}\n")); + } + } + } + + Ok(()) + } + + fn push(&mut self, src: &Source) { + self.text.push_str(&format!("\tpush {src}\n")); + } + + fn pop(&mut self, dest: &Destination) { + self.text.push_str(&format!("\tpop {dest}\n")); + } + + fn ty_size(&self, ty: &Ty) -> usize { + ty.size(Self::BITNESS) + } + + fn stack_frame_size(&self, block: &Block) -> usize { + let mut size = 0; + + for stmt in block.0 { + match stmt { + Stmt::Local(stmt) => { + size += self.ty_size(self.ctx.resolve_ty(stmt.ty)); + } + Stmt::Item(_) | Stmt::Expr(_) | Stmt::Return(_) => (), + } + } + + size + } + + fn add( + &mut self, + lhs: &Source, + rhs: &Source, + dest: &Destination, + signed: bool, + ) -> Result<(), Amd64AsmError> { + lhs.size().map(|size| assert!(size == dest.size())); + rhs.size().map(|size| assert!(size == dest.size())); + assert!(!(lhs == dest && rhs == dest)); + + let lhs = if let Source::Immediate(_) = lhs { + self.mov(lhs, &Register::Rax.resize(dest.size()).into(), signed)?; + &Register::Rax.resize(dest.size()).into() + } else { + lhs + }; + + self.text.push_str(&format!("\tadd {lhs}, {rhs}\n")); + + if lhs != dest { + self.mov(lhs, dest, signed)?; + } + + Ok(()) + } + + fn sub( + &mut self, + lhs: &Source, + rhs: &Source, + dest: &Destination, + signed: bool, + ) -> Result<(), Amd64AsmError> { + lhs.size().map(|size| assert!(size == dest.size())); + rhs.size().map(|size| assert!(size == dest.size())); + assert!(!(lhs == dest && rhs == dest)); + + let lhs = if let Source::Immediate(_) = lhs { + self.mov(lhs, &Register::Rax.resize(dest.size()).into(), signed)?; + &Register::Rax.resize(dest.size()).into() + } else { + lhs + }; + + self.text.push_str(&format!("\tsub {lhs}, {rhs}\n")); + + if lhs != dest { + self.mov(lhs, dest, signed)?; + } + + Ok(()) + } + + fn mul( + &mut self, + lhs: &Source, + rhs: &Source, + dest: &Destination, + signed: bool, + ) -> Result<(), Amd64AsmError> { + lhs.size().map(|size| assert_eq!(size, dest.size())); + rhs.size().map(|size| assert_eq!(size, dest.size())); + assert!(!(lhs == dest && rhs == dest)); + + self.mov( + lhs, + &lhs.size() + .map_or(Register::Rax, |size| Register::Rax.resize(size)) + .into(), + signed, + )?; + + if rhs != dest { + self.mov(rhs, dest, signed)?; + } + if self.allocator.is_used(&Register::Rdx) { + self.push(&Register::Rdx.into()); + } + + self.text.push_str(&format!("\timul {dest}\n")); + + if self.allocator.is_used(&Register::Rdx) { + self.pop(&Register::Rdx.into()); + } + + self.mov(&Register::Rax.resize(dest.size()).into(), dest, signed)?; + + Ok(()) + } + + //NOTE: if mafs isn't mafsing, probably because of this + fn div( + &mut self, + lhs: &Source, + rhs: &Source, + dest: &Destination, + signed: bool, + ) -> Result<(), Amd64AsmError> { + lhs.size().map(|size| assert_eq!(size, dest.size())); + rhs.size().map(|size| assert_eq!(size, dest.size())); + assert!(!(lhs == dest && rhs == dest)); + + self.mov(lhs, &Register::Rax.into(), signed)?; + self.mov(rhs, dest, signed)?; + if self.allocator.is_used(&Register::Rdx) { + self.push(&Register::Rdx.into()); + } + + self.text.push_str(&formatdoc!( + " + \tcqo + \tidiv {dest} + ", + )); + + if self.allocator.is_used(&Register::Rdx) { + self.pop(&Register::Rdx.into()); + } + + self.mov(&Register::Rax.resize(dest.size()).into(), dest, signed)?; + + Ok(()) + } + + fn setcc(&mut self, dest: &Destination, condition: CmpOp) { + self.text.push_str(&format!("\t{condition} {dest}\n")); + } + + fn cmp(&mut self, dest: &Destination, src: &Source) { + self.text.push_str(&format!("\tcmp {dest}, {src}\n")); + } + + fn jcc(&mut self, label: &str, kind: Jump) { + self.text.push_str(&format!("\t{kind} {label}\n")); + } + + fn write_label(&mut self, label: &str) { + self.text.push_str(&format!("{label}:\n")); + } + + fn logical_or( + &mut self, + expr: &Expr, + dest: &Destination, + labels: Option<(String, String)>, + ) -> Result<(), Amd64AsmError> { + let mut empty = false; + let (set, end) = labels.unwrap_or_else(|| { + empty = true; + + (self.label_gen.generate(), self.label_gen.generate()) + }); + + match expr.kind { + ExprKind::Binary(BinOp::LogicalOr, lhs, rhs) => { + self.logical_or(lhs, dest, Some((set.clone(), end.clone())))?; + self.logical_or(rhs, dest, Some((set.clone(), end.clone())))?; + } + _ => { + self.expr(expr, Some(dest))?; + self.cmp(dest, &Source::Immediate(Immediate::UInt(0))); + self.jcc(&set, Jump::NotEqual); + } + }; + + if empty { + self.mov(&Source::Immediate(Immediate::UInt(0)), dest, false)?; + self.jcc(&end, Jump::Unconditional); + self.write_label(&set); + self.mov(&Source::Immediate(Immediate::UInt(1)), dest, false)?; + self.write_label(&end); + } + + Ok(()) + } + + fn logical_and( + &mut self, + expr: &Expr, + dest: &Destination, + labels: Option<(String, String)>, + ) -> Result<(), Amd64AsmError> { + let mut empty = false; + let (set, end) = labels.unwrap_or_else(|| { + empty = true; + + (self.label_gen.generate(), self.label_gen.generate()) + }); + + match expr.kind { + ExprKind::Binary(BinOp::LogicalAnd, lhs, rhs) => { + self.logical_and(lhs, dest, Some((set.clone(), end.clone())))?; + self.logical_and(rhs, dest, Some((set.clone(), end.clone())))?; + } + _ => { + self.expr(expr, Some(dest))?; + self.cmp(dest, &Source::Immediate(Immediate::UInt(0))); + self.jcc(&set, Jump::Equal); + } + }; + + if empty { + self.mov(&Source::Immediate(Immediate::UInt(1)), dest, false)?; + self.jcc(&end, Jump::Unconditional); + self.write_label(&set); + self.mov(&Source::Immediate(Immediate::UInt(0)), dest, false)?; + self.write_label(&end); + } + + Ok(()) + } + + fn bitwise( + &mut self, + lhs: &Source, + rhs: &Source, + dest: &Destination, + op: BitwiseOp, + signed: bool, + ) -> Result<(), Amd64AsmError> { + lhs.size().map(|size| assert_eq!(size, dest.size())); + rhs.size().map(|size| assert_eq!(size, dest.size())); + assert!(!(lhs == dest && rhs == dest)); + + let lhs = if let Source::Immediate(_) = lhs { + self.mov(lhs, &Register::Rax.resize(dest.size()).into(), signed)?; + &Register::Rax.into() + } else { + lhs + }; + + self.text.push_str(&format!("\t{op} {rhs}\n")); + + if lhs != dest { + self.mov(lhs, dest, signed)?; + } + + Ok(()) + } + + fn shl(&mut self, dest: &Destination, src: &Source) -> Result<(), Amd64AsmError> { + self.mov(src, &Register::Rcx.into(), false)?; + self.text + .push_str(&format!("\tshl {dest}, {}\n", Register::Cl)); + + Ok(()) + } + + fn shr(&mut self, dest: &Destination, src: &Source) -> Result<(), Amd64AsmError> { + self.mov(src, &Register::Rcx.into(), false)?; + self.text + .push_str(&format!("\tshr {dest}, {}\n", Register::Cl)); + + Ok(()) + } + + fn lea(&mut self, dest: &Destination, address: &EffectiveAddress) { + self.text.push_str(&format!("\tlea {dest}, {address}\n")); + } + + fn negate(&mut self, dest: &Destination) { + self.text.push_str(&format!("\tneg {dest}\n")); + } + + fn bitwise_not(&mut self, dest: &Destination) { + self.text.push_str(&format!("\tnot {dest}\n")); + } +} diff --git a/src/codegen/amd64_asm/operand.rs b/src/codegen/amd64_asm/operand.rs new file mode 100644 index 0000000..06f1204 --- /dev/null +++ b/src/codegen/amd64_asm/operand.rs @@ -0,0 +1,230 @@ +use super::{register::Register, OperandSize}; +use crate::ir::ExprLit; +use derive_more::derive::Display; +use thiserror::Error; + +#[derive(Debug, Clone, PartialEq, Default)] +pub struct Offset(pub isize); + +impl std::ops::Add for &Offset { + type Output = Offset; + + fn add(self, rhs: isize) -> Self::Output { + Offset(self.0 + rhs) + } +} + +impl std::ops::Sub for &Offset { + type Output = Offset; + + fn sub(self, rhs: isize) -> Self::Output { + Offset(self.0 - rhs) + } +} + +impl std::ops::Add<&Offset> for &Offset { + type Output = Offset; + + fn add(self, rhs: &Offset) -> Self::Output { + Offset(self.0 + rhs.0) + } +} + +impl std::ops::Sub<&Offset> for &Offset { + type Output = Offset; + + fn sub(self, rhs: &Offset) -> Self::Output { + Offset(self.0 - rhs.0) + } +} + +impl std::fmt::Display for Offset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.0 > 0 { + write!(f, " + {}", self.0) + } else if self.0 < 0 { + write!(f, " - {}", self.0.abs()) + } else { + write!(f, "") + } + } +} + +#[derive(Clone, Debug, PartialEq, Display)] +pub enum Base { + Register(Register), + Label(String), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct EffectiveAddress { + pub base: Base, + pub index: Option, + pub scale: Option, + pub displacement: Option, +} + +impl std::fmt::Display for EffectiveAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut str = format!("[{}", self.base); + + if let Some(index) = &self.index { + str.push_str(&format!(" + {}", index)); + } + + if let Some(scale) = self.scale { + str.push_str(&format!("* {scale}")); + } + + if let Some(displacement) = &self.displacement { + str.push_str(&format!("{displacement}")); + } + + write!(f, "{}]", str) + } +} + +#[derive(Debug, Clone, Display)] +pub enum Immediate { + Int(i64), + UInt(u64), + Label(String), +} + +impl Into for String { + fn into(self) -> Immediate { + Immediate::Label(self) + } +} + +#[derive(Debug, Error)] +#[error("String literal is not a valid assembly immediate")] +pub struct ImmediateStrLitError; + +impl TryInto for ExprLit<'_> { + type Error = ImmediateStrLitError; + + fn try_into(self) -> Result { + Ok(match self { + Self::Int(lit) => Immediate::Int(lit), + Self::UInt(lit) => Immediate::UInt(lit), + Self::Bool(lit) => Immediate::UInt(lit.into()), + Self::Null => Immediate::UInt(0), + Self::String(_) => return Err(ImmediateStrLitError), + }) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Memory { + pub effective_address: EffectiveAddress, + pub size: OperandSize, +} + +impl std::fmt::Display for Memory { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.size, self.effective_address) + } +} + +#[derive(Debug, Clone)] +pub enum Source { + Memory(Memory), + Register(Register), + Immediate(Immediate), +} + +impl Source { + pub fn size(&self) -> Option { + match self { + Self::Memory(mem) => Some(mem.size.clone()), + Self::Register(reg) => Some(reg.size()), + Self::Immediate(_) => None, + } + } +} + +impl std::fmt::Display for Source { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Memory(mem) => mem.fmt(f), + Self::Register(r) => r.fmt(f), + Self::Immediate(imm) => imm.fmt(f), + } + } +} + +impl Into for i64 { + fn into(self) -> Source { + Source::Immediate(Immediate::Int(self)) + } +} + +impl Into for u64 { + fn into(self) -> Source { + Source::Immediate(Immediate::UInt(self)) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Destination { + Memory(Memory), + Register(Register), +} + +impl Destination { + pub fn size(&self) -> OperandSize { + match self { + Self::Memory(mem) => mem.size.clone(), + Self::Register(reg) => reg.size(), + } + } +} + +impl Into for Destination { + fn into(self) -> Source { + match self { + Self::Memory(mem) => Source::Memory(mem), + Self::Register(r) => Source::Register(r), + } + } +} + +impl Into for Destination { + fn into(self) -> EffectiveAddress { + match self { + Self::Memory(mem) => mem.effective_address, + Self::Register(r) => EffectiveAddress { + base: Base::Register(r), + index: None, + scale: None, + displacement: None, + }, + } + } +} + +impl std::fmt::Display for Destination { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Memory(mem) => mem.fmt(f), + Self::Register(r) => r.fmt(f), + } + } +} + +impl PartialEq<&Source> for &Destination { + fn eq(&self, other: &&Source) -> bool { + match (self, other) { + (Destination::Memory(lhs), Source::Memory(rhs)) if lhs == rhs => true, + (Destination::Register(lhs), Source::Register(rhs)) if lhs == rhs => true, + _ => false, + } + } +} + +impl PartialEq<&Destination> for &Source { + fn eq(&self, other: &&Destination) -> bool { + other == self + } +} diff --git a/src/codegen/amd64_asm/register.rs b/src/codegen/amd64_asm/register.rs new file mode 100644 index 0000000..79b1594 --- /dev/null +++ b/src/codegen/amd64_asm/register.rs @@ -0,0 +1,324 @@ +use super::{ + operand::{Base, Destination, EffectiveAddress, Offset, Source}, + OperandSize, +}; +use derive_more::derive::Display; + +#[derive(Display, Debug, Copy, Clone, PartialEq)] +pub enum Register { + #[display("al")] + Al, + #[display("ax")] + Ax, + #[display("eax")] + Eax, + #[display("rax")] + Rax, + + #[display("bpl")] + Bpl, + #[display("bp")] + Bp, + #[display("ebp")] + Ebp, + #[display("rbp")] + Rbp, + + #[display("cl")] + Cl, + #[display("cx")] + Cx, + #[display("ecx")] + Ecx, + #[display("rcx")] + Rcx, + + #[display("dl")] + Dl, + #[display("dx")] + Dx, + #[display("edx")] + Edx, + #[display("rdx")] + Rdx, + + #[display("sil")] + Sil, + #[display("si")] + Si, + #[display("esi")] + Esi, + #[display("rsi")] + Rsi, + + #[display("dil")] + Dil, + #[display("di")] + Di, + #[display("edi")] + Edi, + #[display("rdi")] + Rdi, + + #[display("rsp")] + Rsp, + #[display("esp")] + Esp, + #[display("sp")] + Sp, + #[display("spl")] + Spl, + + #[display("r15b")] + R15b, + #[display("r15w")] + R15w, + #[display("r15d")] + R15d, + #[display("r15")] + R15, + + #[display("r14b")] + R14b, + #[display("r14w")] + R14w, + #[display("r14d")] + R14d, + #[display("r14")] + R14, + + #[display("r13b")] + R13b, + #[display("r13w")] + R13w, + #[display("r13d")] + R13d, + #[display("r13")] + R13, + + #[display("r12b")] + R12b, + #[display("r12w")] + R12w, + #[display("r12d")] + R12d, + #[display("r12")] + R12, + + #[display("r11b")] + R11b, + #[display("r11w")] + R11w, + #[display("r11d")] + R11d, + #[display("r11")] + R11, + + #[display("r10b")] + R10b, + #[display("r10w")] + R10w, + #[display("r10d")] + R10d, + #[display("r10")] + R10, + + #[display("r9b")] + R9b, + #[display("r9w")] + R9w, + #[display("r9d")] + R9d, + #[display("r9")] + R9, + + #[display("r8b")] + R8b, + #[display("r8w")] + R8w, + #[display("r8d")] + R8d, + #[display("r8")] + R8, +} + +impl Register { + pub fn resize(self, size: OperandSize) -> Self { + match (self, size) { + (Self::Al | Self::Ax | Self::Eax | Self::Rax, OperandSize::Byte) => Self::Al, + (Self::Al | Self::Ax | Self::Eax | Self::Rax, OperandSize::Word) => Self::Ax, + (Self::Al | Self::Ax | Self::Eax | Self::Rax, OperandSize::Dword) => Self::Eax, + (Self::Al | Self::Ax | Self::Eax | Self::Rax, OperandSize::Qword) => Self::Rax, + + (Self::Bpl | Self::Bp | Self::Ebp | Self::Rbp, OperandSize::Byte) => Self::Bpl, + (Self::Bpl | Self::Bp | Self::Ebp | Self::Rbp, OperandSize::Word) => Self::Bp, + (Self::Bpl | Self::Bp | Self::Ebp | Self::Rbp, OperandSize::Dword) => Self::Ebp, + (Self::Bpl | Self::Bp | Self::Ebp | Self::Rbp, OperandSize::Qword) => Self::Rbp, + + (Self::Cl | Self::Cx | Self::Ecx | Self::Rcx, OperandSize::Byte) => Self::Cl, + (Self::Cl | Self::Cx | Self::Ecx | Self::Rcx, OperandSize::Word) => Self::Cx, + (Self::Cl | Self::Cx | Self::Ecx | Self::Rcx, OperandSize::Dword) => Self::Ecx, + (Self::Cl | Self::Cx | Self::Ecx | Self::Rcx, OperandSize::Qword) => Self::Rcx, + + (Self::Dl | Self::Dx | Self::Edx | Self::Rdx, OperandSize::Byte) => Self::Dl, + (Self::Dl | Self::Dx | Self::Edx | Self::Rdx, OperandSize::Word) => Self::Dx, + (Self::Dl | Self::Dx | Self::Edx | Self::Rdx, OperandSize::Dword) => Self::Edx, + (Self::Dl | Self::Dx | Self::Edx | Self::Rdx, OperandSize::Qword) => Self::Rdx, + + (Self::Sil | Self::Si | Self::Esi | Self::Rsi, OperandSize::Byte) => Self::Sil, + (Self::Sil | Self::Si | Self::Esi | Self::Rsi, OperandSize::Word) => Self::Si, + (Self::Sil | Self::Si | Self::Esi | Self::Rsi, OperandSize::Dword) => Self::Esi, + (Self::Sil | Self::Si | Self::Esi | Self::Rsi, OperandSize::Qword) => Self::Rsi, + + (Self::Dil | Self::Di | Self::Edi | Self::Rdi, OperandSize::Byte) => Self::Dil, + (Self::Dil | Self::Di | Self::Edi | Self::Rdi, OperandSize::Word) => Self::Di, + (Self::Dil | Self::Di | Self::Edi | Self::Rdi, OperandSize::Dword) => Self::Edi, + (Self::Dil | Self::Di | Self::Edi | Self::Rdi, OperandSize::Qword) => Self::Rdi, + + (Self::Spl | Self::Sp | Self::Esp | Self::Rsp, OperandSize::Byte) => Self::Spl, + (Self::Spl | Self::Sp | Self::Esp | Self::Rsp, OperandSize::Word) => Self::Sp, + (Self::Spl | Self::Sp | Self::Esp | Self::Rsp, OperandSize::Dword) => Self::Esp, + (Self::Spl | Self::Sp | Self::Esp | Self::Rsp, OperandSize::Qword) => Self::Rsp, + + (Self::R15b | Self::R15w | Self::R15d | Self::R15, OperandSize::Byte) => Self::R15b, + (Self::R15b | Self::R15w | Self::R15d | Self::R15, OperandSize::Word) => Self::R15w, + (Self::R15b | Self::R15w | Self::R15d | Self::R15, OperandSize::Dword) => Self::R15d, + (Self::R15b | Self::R15w | Self::R15d | Self::R15, OperandSize::Qword) => Self::R15, + + (Self::R14b | Self::R14w | Self::R14d | Self::R14, OperandSize::Byte) => Self::R14b, + (Self::R14b | Self::R14w | Self::R14d | Self::R14, OperandSize::Word) => Self::R14w, + (Self::R14b | Self::R14w | Self::R14d | Self::R14, OperandSize::Dword) => Self::R14d, + (Self::R14b | Self::R14w | Self::R14d | Self::R14, OperandSize::Qword) => Self::R14, + + (Self::R13b | Self::R13w | Self::R13d | Self::R13, OperandSize::Byte) => Self::R13b, + (Self::R13b | Self::R13w | Self::R13d | Self::R13, OperandSize::Word) => Self::R13w, + (Self::R13b | Self::R13w | Self::R13d | Self::R13, OperandSize::Dword) => Self::R13d, + (Self::R13b | Self::R13w | Self::R13d | Self::R13, OperandSize::Qword) => Self::R13, + + (Self::R12b | Self::R12w | Self::R12d | Self::R12, OperandSize::Byte) => Self::R12b, + (Self::R12b | Self::R12w | Self::R12d | Self::R12, OperandSize::Word) => Self::R12w, + (Self::R12b | Self::R12w | Self::R12d | Self::R12, OperandSize::Dword) => Self::R12d, + (Self::R12b | Self::R12w | Self::R12d | Self::R12, OperandSize::Qword) => Self::R12, + + (Self::R11b | Self::R11w | Self::R11d | Self::R11, OperandSize::Byte) => Self::R11b, + (Self::R11b | Self::R11w | Self::R11d | Self::R11, OperandSize::Word) => Self::R11w, + (Self::R11b | Self::R11w | Self::R11d | Self::R11, OperandSize::Dword) => Self::R11d, + (Self::R11b | Self::R11w | Self::R11d | Self::R11, OperandSize::Qword) => Self::R11, + + (Self::R10b | Self::R10w | Self::R10d | Self::R10, OperandSize::Byte) => Self::R10b, + (Self::R10b | Self::R10w | Self::R10d | Self::R10, OperandSize::Word) => Self::R10w, + (Self::R10b | Self::R10w | Self::R10d | Self::R10, OperandSize::Dword) => Self::R10d, + (Self::R10b | Self::R10w | Self::R10d | Self::R10, OperandSize::Qword) => Self::R10, + + (Self::R9b | Self::R9w | Self::R9d | Self::R9, OperandSize::Byte) => Self::R9b, + (Self::R9b | Self::R9w | Self::R9d | Self::R9, OperandSize::Word) => Self::R9w, + (Self::R9b | Self::R9w | Self::R9d | Self::R9, OperandSize::Dword) => Self::R9d, + (Self::R9b | Self::R9w | Self::R9d | Self::R9, OperandSize::Qword) => Self::R9, + + (Self::R8b | Self::R8w | Self::R8d | Self::R8, OperandSize::Byte) => Self::R8b, + (Self::R8b | Self::R8w | Self::R8d | Self::R8, OperandSize::Word) => Self::R8w, + (Self::R8b | Self::R8w | Self::R8d | Self::R8, OperandSize::Dword) => Self::R8d, + (Self::R8b | Self::R8w | Self::R8d | Self::R8, OperandSize::Qword) => Self::R8, + } + } + + pub fn size(&self) -> OperandSize { + match self { + Self::Al => OperandSize::Byte, + Self::Ax => OperandSize::Word, + Self::Eax => OperandSize::Dword, + Self::Rax => OperandSize::Qword, + + Self::Bpl => OperandSize::Byte, + Self::Bp => OperandSize::Word, + Self::Ebp => OperandSize::Dword, + Self::Rbp => OperandSize::Qword, + + Self::Cl => OperandSize::Byte, + Self::Cx => OperandSize::Word, + Self::Ecx => OperandSize::Dword, + Self::Rcx => OperandSize::Qword, + + Self::Dl => OperandSize::Byte, + Self::Dx => OperandSize::Word, + Self::Edx => OperandSize::Dword, + Self::Rdx => OperandSize::Qword, + + Self::Sil => OperandSize::Byte, + Self::Si => OperandSize::Word, + Self::Esi => OperandSize::Dword, + Self::Rsi => OperandSize::Qword, + + Self::Dil => OperandSize::Byte, + Self::Di => OperandSize::Word, + Self::Edi => OperandSize::Dword, + Self::Rdi => OperandSize::Qword, + + Self::Spl => OperandSize::Byte, + Self::Sp => OperandSize::Word, + Self::Esp => OperandSize::Dword, + Self::Rsp => OperandSize::Qword, + + Self::R15b => OperandSize::Byte, + Self::R15w => OperandSize::Word, + Self::R15d => OperandSize::Dword, + Self::R15 => OperandSize::Qword, + + Self::R14b => OperandSize::Byte, + Self::R14w => OperandSize::Word, + Self::R14d => OperandSize::Dword, + Self::R14 => OperandSize::Qword, + + Self::R13b => OperandSize::Byte, + Self::R13w => OperandSize::Word, + Self::R13d => OperandSize::Dword, + Self::R13 => OperandSize::Qword, + + Self::R12b => OperandSize::Byte, + Self::R12w => OperandSize::Word, + Self::R12d => OperandSize::Dword, + Self::R12 => OperandSize::Qword, + + Self::R11b => OperandSize::Byte, + Self::R11w => OperandSize::Word, + Self::R11d => OperandSize::Dword, + Self::R11 => OperandSize::Qword, + + Self::R10b => OperandSize::Byte, + Self::R10w => OperandSize::Word, + Self::R10d => OperandSize::Dword, + Self::R10 => OperandSize::Qword, + + Self::R9b => OperandSize::Byte, + Self::R9w => OperandSize::Word, + Self::R9d => OperandSize::Dword, + Self::R9 => OperandSize::Qword, + + Self::R8b => OperandSize::Byte, + Self::R8w => OperandSize::Word, + Self::R8d => OperandSize::Dword, + Self::R8 => OperandSize::Qword, + } + } + + pub fn into_effective_addr(self, displacement: isize) -> EffectiveAddress { + EffectiveAddress { + base: Base::Register(self), + index: None, + scale: None, + displacement: Some(Offset(displacement)), + } + } +} + +impl Into for Register { + fn into(self) -> Destination { + Destination::Register(self) + } +} + +impl Into for Register { + fn into(self) -> Source { + Source::Register(self) + } +} diff --git a/src/codegen/argument.rs b/src/codegen/argument.rs deleted file mode 100644 index 53b717b..0000000 --- a/src/codegen/argument.rs +++ /dev/null @@ -1,6 +0,0 @@ -use crate::register::Register; - -pub enum Argument { - Register(Register), - Stack(usize), -} diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs deleted file mode 100644 index 688f310..0000000 --- a/src/codegen/codegen.rs +++ /dev/null @@ -1,1249 +0,0 @@ -use super::{ - operands, Argument, CodeGenError, Destination, EffectiveAddress, Immediate, Offset, - SethiUllman, Source, -}; -use crate::{ - archs::{Arch, Jump}, - parser::{ - BinOp, BitwiseOp, CmpOp, Expr, ExprArray, ExprArrayAccess, ExprBinary, ExprFunctionCall, - ExprIdent, ExprLit, ExprStruct, ExprStructAccess, ExprStructMethod, ExprUnary, Expression, - Stmt, StmtFor, StmtFunction, StmtIf, StmtReturn, StmtVarDecl, StmtWhile, UnOp, - }, - register::Register, - scope::Scope, - symbol_table::Symbol, - type_table as tt, - types::{Type, TypeError}, -}; -use operands::{Base, Memory}; - -enum ScopeInfo { - Function { label: String }, - Loop { start: String, end: String }, -} - -#[derive(Debug, Clone)] -pub struct State { - false_label: Option, - end_label: Option, -} - -pub struct CodeGen { - pub arch: Arch, - pub scope: Scope, - scope_infos: Vec, -} - -impl CodeGen { - pub fn new(arch: Arch, scope: Scope) -> Self { - Self { - arch, - scope, - scope_infos: Vec::new(), - } - } - - fn declare(&mut self, variable: StmtVarDecl) -> Result<(), CodeGenError> { - if !self.scope.local() { - self.arch - .declare(&variable.name, self.arch.size(&variable.type_, &self.scope)); - } - - if let Some(expr) = variable.value { - self.expr( - Expr::Binary(ExprBinary { - op: BinOp::Assign, - left: Box::new(Expr::Ident(ExprIdent(variable.name))), - right: Box::new(expr), - }), - None, - None, - )?; - } - - Ok(()) - } - - fn function(&mut self, mut func: StmtFunction) -> Result<(), CodeGenError> { - let offset = self - .arch - .populate_offsets(&mut func.block, &self.scope, Default::default())? - .unsigned_abs() - .next_multiple_of(self.arch.stack_alignment()); - - self.scope.enter(func.block.scope); - self.arch.fn_preamble( - &func.name, - &func - .params - .iter() - .map(|(_, type_)| type_.to_owned()) - .collect::>(), - offset, - &self.scope, - )?; - self.scope_infos.push(ScopeInfo::Function { - label: func.name.clone(), - }); - - for stmt in func.block.statements { - self.stmt(stmt)?; - } - - self.scope_infos.pop(); - self.arch.fn_postamble(&func.name, offset); - self.scope.leave(); - - Ok(()) - } - - fn ret(&mut self, ret: StmtReturn) -> Result<(), CodeGenError> { - if let Some(expr) = ret.expr { - let type_ = expr.type_(&self.scope)?; - let r = self.arch.alloc()?; - - self.expr( - expr, - Some(&r.dest(self.arch.size(&type_, &self.scope))), - None, - )?; - self.arch.ret( - &Source::Register(operands::Register { - register: r, - size: self.arch.size(&type_, &self.scope), - }), - type_.signed(), - )?; - self.arch.free(r)?; - } - - if let Some(label) = self.function_scope_info() { - self.arch.jcc(&format!("{label}_ret"), Jump::Unconditional); - - Ok(()) - } else { - unreachable!(); - } - } - - fn if_stmt(&mut self, if_stmt: StmtIf) -> Result<(), CodeGenError> { - let r = self.arch.alloc()?; - let expr_size = self - .arch - .size(&if_stmt.condition.type_(&self.scope)?, &self.scope); - - self.expr(if_stmt.condition, Some(&r.dest(expr_size)), None)?; - self.arch.cmp( - &Destination::Register(operands::Register { - register: r, - size: expr_size, - }), - &Source::Immediate(Immediate::UInt(0)), - ); - self.arch.free(r)?; - - let consequence_label = self.arch.generate_label(); - let alternative_label = self.arch.generate_label(); - - self.arch.jcc(&alternative_label, Jump::Equal); - self.scope.enter(if_stmt.consequence.scope); - for stmt in if_stmt.consequence.statements { - self.stmt(stmt)?; - } - self.scope.leave(); - - if if_stmt.alternative.is_some() { - self.arch.jcc(&consequence_label, Jump::Unconditional); - } - - self.arch.write_label(&alternative_label); - if let Some(block) = if_stmt.alternative { - self.scope.enter(block.scope); - - for stmt in block.statements { - self.stmt(stmt)?; - } - - self.scope.leave(); - self.arch.write_label(&consequence_label); - } - - Ok(()) - } - - fn while_stmt(&mut self, stmt: StmtWhile) -> Result<(), CodeGenError> { - let start_label = self.arch.generate_label(); - let end_label = self.arch.generate_label(); - let r = self.arch.alloc()?; - - self.arch.write_label(&start_label); - self.expr(stmt.condition, Some(&r.dest(1)), None)?; - self.arch.cmp( - &Destination::Register(operands::Register { - register: r, - size: 1, - }), - &Source::Immediate(Immediate::UInt(0)), - ); - self.arch.jcc(&end_label, Jump::Equal); - self.scope_infos.push(ScopeInfo::Loop { - start: start_label.clone(), - end: end_label.clone(), - }); - - self.scope.enter(stmt.block.scope); - for stmt in stmt.block.statements { - self.stmt(stmt)?; - } - self.scope.leave(); - - self.scope_infos.pop(); - self.arch.jcc(&start_label, Jump::Unconditional); - self.arch.write_label(&end_label); - - self.arch.free(r)?; - - Ok(()) - } - - fn for_stmt(&mut self, stmt: StmtFor) -> Result<(), CodeGenError> { - let start_label = self.arch.generate_label(); - let end_label = self.arch.generate_label(); - let increment_label = self.arch.generate_label(); - let r = self.arch.alloc()?; - - self.scope.enter(stmt.block.scope); - - if let Some(initializer) = stmt.initializer { - self.stmt(*initializer)?; - } - - self.arch.write_label(&start_label); - if let Some(condition) = stmt.condition { - self.expr(condition, Some(&r.dest(1)), None)?; - self.arch.cmp( - &Destination::Register(operands::Register { - register: r, - size: 1, - }), - &Source::Immediate(Immediate::UInt(0)), - ); - self.arch.jcc(&end_label, Jump::Equal); - } - - self.scope_infos.push(ScopeInfo::Loop { - start: increment_label.clone(), - end: end_label.clone(), - }); - - for stmt in stmt.block.statements { - self.stmt(stmt)?; - } - - self.scope_infos.pop(); - - self.arch.write_label(&increment_label); - if let Some(increment) = stmt.increment { - self.expr(increment, None, None)?; - } - - self.scope.leave(); - - self.arch.jcc(&start_label, Jump::Unconditional); - self.arch.write_label(&end_label); - - self.arch.free(r)?; - - Ok(()) - } - - pub fn expr( - &mut self, - expr: Expr, - dest: Option<&Destination>, - state: Option<&State>, - ) -> Result<(), CodeGenError> { - match expr { - Expr::Binary(bin_expr) => self.bin_expr(bin_expr, dest, state)?, - Expr::Lit(lit) => { - if let Some(dest) = dest { - if let ExprLit::String(literal) = &lit { - let label = self.arch.define_literal(literal.to_owned()); - - self.arch - .mov(&Source::Immediate(Immediate::Label(label)), dest, false)?; - } else { - let signed = lit.signed(); - self.arch - .mov(&Source::Immediate(lit.into()), dest, signed)? - } - } - } - Expr::Unary(unary_expr) => { - if let Some(dest) = dest { - self.unary_expr(unary_expr, dest, state)? - } - } - Expr::Ident(ident) => { - if let Some(dest) = dest { - let source = self.arch.symbol_source(&ident.0, &self.scope)?; - - // If the ident is of type array, the address of variable has to be moved, not the value - if let Type::Array(_) = ident.type_(&self.scope)? { - let r = self.arch.alloc()?; - let r_op = operands::Register { - register: r, - size: self.arch.word_size(), - }; - - self.arch - .lea(&Destination::Register(r_op.clone()), &source.clone().into()); - self.arch.mov(&Source::Register(r_op), dest, false)?; - self.arch.free(r)?; - } else { - self.arch.mov(&source, dest, false)?; - } - - source.free(&mut self.arch)?; - } - } - Expr::Cast(cast_expr) => { - if let Some(dest) = dest { - let type_ = cast_expr.expr.type_(&self.scope)?; - let og_size = if let Type::Array(_) = &type_ { - self.arch.word_size() - } else { - self.arch.size(&type_, &self.scope) - }; - assert!(og_size <= 8); - let casted_size = self.arch.size(&cast_expr.type_, &self.scope); - - if casted_size != og_size { - let (r, new) = match dest { - Destination::Memory(_) => (self.arch.alloc()?, true), - Destination::Register(register) => (register.register, false), - }; - - self.expr(*cast_expr.expr, Some(&r.dest(og_size)), state)?; - - if casted_size > og_size { - self.arch.mov( - &r.source(og_size), - &r.dest(casted_size), - type_.signed(), - )?; - } - - if new { - self.arch - .mov(&r.source(casted_size), dest, type_.signed())?; - self.arch.free(r)?; - } - } else { - self.expr(*cast_expr.expr, Some(dest), state)?; - } - } - } - Expr::FunctionCall(func_call) => self.call_function(func_call, dest, state)?, - Expr::Struct(expr) => { - if let Some(dest) = dest { - self.struct_expr(expr, dest, state)? - } - } - Expr::Array(expr) => { - if let Some(dest) = dest { - self.array_expr(expr, dest, state)? - } - } - Expr::StructAccess(expr) => { - if let Some(dest) = dest { - self.struct_access(expr, dest)?; - } - } - Expr::StructMethod(expr) => { - self.struct_call_method(expr, dest, state)?; - } - Expr::ArrayAccess(expr) => { - if let Some(dest) = dest { - self.array_access(expr, dest)?; - } - } - Expr::MacroCall(_) => unreachable!(), - }; - - Ok(()) - } - - fn bin_expr( - &mut self, - mut expr: ExprBinary, - dest: Option<&Destination>, - state: Option<&State>, - ) -> Result<(), CodeGenError> { - let type_ = expr.type_(&self.scope)?; - let left_type = expr.left.type_(&self.scope)?; - let right_type = expr.right.type_(&self.scope)?; - let size = self.arch.size(&type_, &self.scope); - let signed = type_.signed(); - - match &expr.op { - BinOp::Assign => { - let lvalue_dest = self.expr_dest(*expr.left)?; - - self.expr(*expr.right, Some(&lvalue_dest), state)?; - - if let Some(dest) = dest { - self.arch.mov(&lvalue_dest.clone().into(), dest, signed)?; - } - - self.free(lvalue_dest)?; - } - BinOp::Add - | BinOp::Sub - | BinOp::Mul - | BinOp::Div - | BinOp::BitwiseAnd - | BinOp::BitwiseOr => { - if let Some(dest) = dest { - let expr_src = |codegen: &mut Self, - expr: Expr, - size: usize, - dest: Option, - state: Option<&State>| - -> Result<(Source, Option), CodeGenError> { - let (dest, r) = match dest{ - Some(dest) => (dest, None), - None => { - let r =codegen.arch.alloc()?; - - (r.dest(size), Some(r)) - } - }; - - Ok(match expr { - Expr::Lit(lit) => (Source::Immediate(lit.into()),r), - _ => { - codegen.expr(expr, Some(&dest), state)?; - - (dest.into(), r) - } - }) - }; - - if expr.op == BinOp::Add { - // Canonicalize foo + T* to T* + foo - if type_.ptr() && expr.right.type_(&self.scope).unwrap().ptr() { - std::mem::swap(expr.left.as_mut(), expr.right.as_mut()); - } - } - - let (expr_dest, dest_r) = match dest { - Destination::Memory(_) => { - let r = self.arch.alloc()?; - - (r.dest(size), Some(r)) - } - Destination::Register(register) if register.size != size => { - (dest.to_owned().with_size(size), None) - } - _ => (dest.to_owned(), None), - }; - - let (lhs, rhs, expr_r) = if expr.left.r_num() > expr.right.r_num() { - let (lhs, _) = - expr_src(self, *expr.left, size, Some(expr_dest.clone()), state)?; - let (rhs, r) = expr_src(self, *expr.right, size, None, state)?; - - (lhs, rhs, r) - } else { - let (rhs, _) = - expr_src(self, *expr.right, size, Some(expr_dest.clone()), state)?; - let (lhs, r) = expr_src(self, *expr.left, size, None, state)?; - - (lhs, rhs, r) - }; - - match &expr.op { - BinOp::Add => match (left_type, right_type) { - (Type::Ptr(type_), _) => { - let r = self.arch.alloc()?; - - self.arch.mul( - &Source::Immediate(Immediate::UInt( - self.arch.size(&type_, &self.scope) as u64, - )), - &rhs, - &r.dest(size), - signed, - )?; - self.arch.add(&lhs, &r.source(size), &expr_dest, false)?; - - self.arch.free(r)?; - } - _ => self.arch.add(&lhs, &rhs, &expr_dest, signed)?, - }, - BinOp::Sub => match (left_type, right_type) { - // type_ and _ are the same - (Type::Ptr(type_), Type::Ptr(_)) => { - let r = self.arch.alloc()?; - - self.arch.sub(&lhs, &rhs, &expr_dest, signed)?; - self.arch.div( - &expr_dest.clone().into(), - &Source::Immediate(Immediate::UInt( - self.arch.size(&type_, &self.scope) as u64, - )), - &r.dest(size), - true, - )?; - self.arch.mov(&r.source(size), &expr_dest, true)?; - - self.arch.free(r)?; - } - (Type::Ptr(type_), _) => { - let r = self.arch.alloc()?; - - self.arch.mul( - &Source::Immediate(Immediate::UInt( - self.arch.size(&type_, &self.scope) as u64, - )), - &rhs, - &r.dest(size), - signed, - )?; - self.arch.sub(&lhs, &r.source(size), &expr_dest, false)?; - - self.arch.free(r)?; - } - _ => self.arch.sub(&lhs, &rhs, &expr_dest, signed)?, - }, - BinOp::Mul => { - self.arch.mul(&lhs, &rhs, &expr_dest, signed)?; - } - BinOp::Div => { - self.arch.div(&lhs, &rhs, &expr_dest, signed)?; - } - BinOp::BitwiseAnd | BinOp::BitwiseOr => { - self.arch.bitwise( - &lhs, - &rhs, - &expr_dest, - BitwiseOp::try_from(&expr.op).unwrap(), - signed, - )?; - } - _ => unreachable!(), - }; - - if dest.size() != expr_dest.size() { - self.arch.mov( - &expr_dest.clone().into(), - &expr_dest.clone().with_size(dest.size()), - signed, - )?; - } - if let Destination::Memory(_) = dest { - self.arch.mov( - &expr_dest.clone().with_size(dest.size()).into(), - dest, - signed, - )?; - } - if let Some(r) = dest_r { - self.arch.free(r)?; - } - if let Some(r) = expr_r { - self.arch.free(r)?; - } - } - } - BinOp::LessThan - | BinOp::LessEqual - | BinOp::GreaterThan - | BinOp::GreaterEqual - | BinOp::Equal - | BinOp::NotEqual => { - if let Some(dest) = dest { - let size = self - .arch - .size(&Type::common_type(left_type, right_type), &self.scope); - let left = self.arch.alloc()?; - let right = self.arch.alloc()?; - - self.expr(*expr.left, Some(&left.dest(size)), state)?; - self.expr(*expr.right, Some(&right.dest(size)), state)?; - self.arch.cmp(&left.dest(size), &right.source(size)); - self.arch.setcc(dest, CmpOp::try_from(&expr.op)?); - self.arch.free(left)?; - self.arch.free(right)?; - } - } - BinOp::Shl => { - if let Some(dest) = dest { - let r = self.arch.alloc()?; - - self.expr(*expr.left, Some(dest), state)?; - self.expr(*expr.right, Some(&r.dest(size)), state)?; - - self.arch.shl( - dest, - &Source::Register(operands::Register { register: r, size }), - )?; - - self.arch.free(r)?; - } - } - BinOp::Shr => { - if let Some(dest) = dest { - let r = self.arch.alloc()?; - - self.expr(*expr.left, Some(dest), state)?; - self.expr(*expr.right, Some(&r.dest(size)), state)?; - - self.arch.shr( - dest, - &Source::Register(operands::Register { register: r, size }), - )?; - - self.arch.free(r)?; - } - } - BinOp::LogicalAnd | BinOp::LogicalOr => { - if let Some(dest) = dest { - let mut parent = false; - let state = state.map(|state| state.to_owned()).unwrap_or_else(|| { - parent = true; - - State { - false_label: Some(self.arch.generate_label()), - end_label: Some(self.arch.generate_label()), - } - }); - let eval = - |codegen: &mut Self, expr: Expr, op: &BinOp| -> Result<(), CodeGenError> { - let cmp = match &expr { - Expr::Binary(ExprBinary { - op: BinOp::LogicalAnd, - .. - }) - | Expr::Binary(ExprBinary { - op: BinOp::LogicalOr, - .. - }) => false, - _ => true, - }; - - let opposite = if op == &BinOp::LogicalAnd { - &BinOp::LogicalOr - } else { - &BinOp::LogicalAnd - }; - - match &expr { - Expr::Binary(ExprBinary { op, .. }) if op == opposite => { - codegen.expr(expr, Some(dest), None)?; - } - _ => { - codegen.expr(expr, Some(dest), Some(&state))?; - } - } - - if cmp { - codegen - .arch - .cmp(dest, &Source::Immediate(Immediate::UInt(0))); - - if op == &BinOp::LogicalAnd { - codegen - .arch - .jcc(&state.false_label.as_ref().unwrap(), Jump::Equal); - } else { - codegen - .arch - .jcc(&state.false_label.as_ref().unwrap(), Jump::NotEqual); - } - } - - Ok(()) - }; - - eval(self, *expr.left, &expr.op)?; - eval(self, *expr.right, &expr.op)?; - - if parent { - if expr.op == BinOp::LogicalAnd { - self.arch - .mov(&Source::Immediate(Immediate::UInt(1)), dest, false)?; - self.arch - .jcc(state.end_label.as_ref().unwrap(), Jump::Unconditional); - self.arch.write_label(state.false_label.as_ref().unwrap()); - self.arch - .mov(&Source::Immediate(Immediate::UInt(0)), dest, false)?; - } else { - self.arch - .mov(&Source::Immediate(Immediate::UInt(0)), dest, false)?; - self.arch - .jcc(state.end_label.as_ref().unwrap(), Jump::Unconditional); - self.arch.write_label(state.false_label.as_ref().unwrap()); - self.arch - .mov(&Source::Immediate(Immediate::UInt(1)), dest, false)?; - } - - self.arch.write_label(state.end_label.as_ref().unwrap()); - } - } - } - }; - - Ok(()) - } - - fn stmt(&mut self, stmt: Stmt) -> Result<(), CodeGenError> { - match stmt { - Stmt::Expr(expr) => self.expr(expr, None, None).map(|_| ()), - Stmt::VarDecl(var_decl) => self.declare(var_decl), - Stmt::Function(func) => self.function(func), - Stmt::Return(ret) => self.ret(ret), - Stmt::If(stmt) => self.if_stmt(stmt), - Stmt::While(stmt) => self.while_stmt(stmt), - Stmt::For(stmt) => self.for_stmt(stmt), - Stmt::Continue => { - if let Some((start, _)) = self.loop_scope_info() { - self.arch.jcc(&start, Jump::Unconditional); - - Ok(()) - } else { - unreachable!(); - } - } - Stmt::Break => { - if let Some((_, end)) = self.loop_scope_info() { - self.arch.jcc(&end, Jump::Unconditional); - - Ok(()) - } else { - unreachable!(); - } - } - } - } - - fn unary_expr( - &mut self, - unary_expr: ExprUnary, - dest: &Destination, - state: Option<&State>, - ) -> Result<(), CodeGenError> { - match unary_expr.op { - UnOp::Negative => { - self.expr(*unary_expr.expr, Some(dest), state)?; - self.arch.negate(dest); - } - UnOp::LogicalNot => { - let type_ = unary_expr.type_(&self.scope)?; - let r = self.arch.alloc()?; - - self.expr( - *unary_expr.expr, - Some(&r.dest(self.arch.size(&type_, &self.scope))), - state, - )?; - self.arch.cmp( - &r.dest(self.arch.size(&type_, &self.scope)), - &Source::Immediate(Immediate::UInt(0)), - ); - self.arch.setcc(dest, CmpOp::Equal); - self.arch.free(r)?; - } - UnOp::Address => { - let signed = unary_expr.expr.type_(&self.scope)?.signed(); - let expr_dest = self.expr_dest(*unary_expr.expr)?; - let r = self.arch.alloc()?; - - self.arch.lea( - &Destination::Register(operands::Register { - register: r, - size: self.arch.word_size(), - }), - &expr_dest.clone().into(), - ); - self.arch.mov( - &Source::Register(operands::Register { - register: r, - size: self.arch.word_size(), - }), - dest, - signed, - )?; - - self.arch.free(r)?; - self.free(expr_dest)?; - } - UnOp::Deref => { - let signed = unary_expr.expr.type_(&self.scope)?.signed(); - let expr_dest = self.expr_dest(Expr::Unary(unary_expr))?; - - self.arch.mov(&expr_dest.clone().into(), dest, signed)?; - self.free(expr_dest)?; - } - UnOp::BitwiseNot => { - self.expr(*unary_expr.expr, Some(dest), state)?; - - self.arch.bitwise_not(dest); - } - }; - - Ok(()) - } - - fn call_function( - &mut self, - call: ExprFunctionCall, - dest: Option<&Destination>, - state: Option<&State>, - ) -> Result<(), CodeGenError> { - let mut preceding = Vec::new(); - let mut stack_size = 0; - let mut arg_registers = Vec::new(); - - for expr in call.arguments.clone().into_iter() { - let type_ = expr.type_(&self.scope)?; - let r = self.arch.alloc()?; - - self.expr(expr, Some(&r.dest(self.arch.word_size())), state)?; - let arg = self.arch.push_arg( - Source::Register(operands::Register { - register: r, - size: self.arch.word_size(), - }), - &type_, - &preceding, - ); - - match arg { - Argument::Stack(size) => stack_size += size, - Argument::Register(r) => arg_registers.push(r), - }; - - self.arch.free(r)?; - preceding.push(type_); - } - - let (_, return_type) = match call.expr.type_(&self.scope)? { - Type::Fn(params, return_type) => (params, return_type), - _ => unreachable!(), - }; - - match *call.expr { - Expr::Ident(expr) - if self.scope.find_symbol(&expr.0).is_some_and(|symbol| { - if let Symbol::Function(_) = symbol { - true - } else { - false - } - }) => - { - self.arch.call( - &Source::Immediate(Immediate::Label(expr.0)), - dest, - return_type.signed(), - self.arch.size(&return_type, &self.scope), - )?; - } - _ => { - let r = self.arch.alloc()?; - - self.expr(*call.expr, Some(&r.dest(self.arch.word_size())), state)?; - self.arch.call( - &r.source(self.arch.word_size()), - dest, - return_type.signed(), - self.arch.size(&return_type, &self.scope), - )?; - self.arch.free(r)?; - } - }; - - for r in arg_registers { - self.arch.free(r)?; - } - if stack_size > 0 { - self.arch.shrink_stack(stack_size); - } - - Ok(()) - } - - fn struct_expr( - &mut self, - expr: ExprStruct, - dest: &Destination, - state: Option<&State>, - ) -> Result<(), CodeGenError> { - let type_struct = match self - .scope - .find_type(&expr.name) - .ok_or(TypeError::Nonexistent(expr.name))? - { - tt::Type::Struct(type_) => type_, - } - .clone(); - //NOTE: clone bad ^ - - for (name, expr) in expr.fields.into_iter() { - let offset = type_struct.offset(&self.arch, &name, &self.scope)?; - let field_size = self - .arch - .size(&type_struct.get_field_type(&name).unwrap(), &self.scope); - - self.expr( - expr, - Some(&match dest.clone() { - Destination::Memory(mut memory) => { - memory.effective_address.displacement = Some( - &memory - .effective_address - .displacement - .unwrap_or(Offset::default()) - + &offset, - ); - memory.size = field_size; - - Destination::Memory(memory) - } - _ => todo!(), - }), - state, - )?; - } - - Ok(()) - } - - fn array_expr( - &mut self, - expr: ExprArray, - dest: &Destination, - state: Option<&State>, - ) -> Result<(), CodeGenError> { - for (i, expr) in expr.0.into_iter().enumerate() { - let size = self.arch.size(&expr.type_(&self.scope)?, &self.scope); - let index = self.arch.alloc()?; - let r = self.arch.alloc()?; - let r_loc = r.dest(self.arch.word_size()); - - self.arch.mov( - &Source::Immediate(Immediate::UInt(i as u64)), - &index.dest(self.arch.word_size()), - false, - )?; - - match dest { - Destination::Memory(memory) => { - self.arch.lea(&r_loc, &memory.effective_address); - } - Destination::Register(_) => unreachable!(), - } - - self.arch.array_offset( - &r_loc.clone().into(), - &index.source(self.arch.word_size()), - size, - &r_loc, - )?; - - self.expr( - expr, - Some(&Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(r), - index: None, - scale: None, - displacement: None, - }, - size, - })), - state, - )?; - self.arch.free(index)?; - self.arch.free(r)?; - } - - Ok(()) - } - - fn struct_access( - &mut self, - expr: ExprStructAccess, - dest: &Destination, - ) -> Result<(), CodeGenError> { - let signed = expr.type_(&self.scope)?.signed(); - let expr_dest = self.expr_dest(Expr::StructAccess(expr))?; - - self.arch.mov(&expr_dest.clone().into(), dest, signed)?; - self.free(expr_dest) - } - - fn struct_call_method( - &mut self, - expr: ExprStructMethod, - dest: Option<&Destination>, - state: Option<&State>, - ) -> Result<(), CodeGenError> { - let struct_name = match expr.expr.type_(&self.scope)? { - Type::Custom(name) => name, - _ => unreachable!(), - }; - let expr_dest = self.expr_dest(*expr.expr)?; - let r = self.arch.alloc()?; - let effective_address = match expr_dest.clone() { - Destination::Memory(memory) => memory.effective_address, - _ => unreachable!(), - }; - - self.arch - .lea(&r.dest(self.arch.word_size()), &effective_address); - - let this = Type::Ptr(Box::new(Type::Custom(struct_name.clone()))); - let mut preceding = Vec::new(); - let mut stack_size = 0; - let mut arg_registers = Vec::new(); - - let arg = self - .arch - .push_arg(r.source(self.arch.word_size()), &this, &preceding); - preceding.push(this); - - match arg { - Argument::Stack(size) => stack_size += size, - Argument::Register(r) => arg_registers.push(r), - }; - - for expr in expr.arguments.into_iter() { - let type_ = expr.type_(&self.scope)?; - let r = self.arch.alloc()?; - - self.expr(expr, Some(&r.dest(self.arch.word_size())), state)?; - let arg = self.arch.push_arg( - Source::Register(operands::Register { - register: r, - size: self.arch.word_size(), - }), - &type_, - &preceding, - ); - - match arg { - Argument::Stack(size) => stack_size += size, - Argument::Register(r) => arg_registers.push(r), - }; - - self.arch.free(r)?; - preceding.push(type_); - } - - let method = match self.scope.find_type(&struct_name).unwrap() { - tt::Type::Struct(type_) => type_.find_method(&expr.method).unwrap(), - }; - self.arch.call( - &Source::Immediate(Immediate::Label(format!("{struct_name}__{}", expr.method))), - dest, - method.return_type.signed(), - self.arch.size(&method.return_type, &self.scope), - )?; - for r in arg_registers { - self.arch.free(r)?; - } - if stack_size > 0 { - self.arch.shrink_stack(stack_size); - } - self.free(expr_dest)?; - self.arch.free(r)?; - - Ok(()) - } - - fn array_access( - &mut self, - expr: ExprArrayAccess, - dest: &Destination, - ) -> Result<(), CodeGenError> { - let signed = expr.type_(&self.scope)?.signed(); - let expr_dest = self.expr_dest(Expr::ArrayAccess(expr))?; - - self.arch.mov(&expr_dest.clone().into(), dest, signed)?; - self.free(expr_dest) - } - - fn free(&mut self, dest: Destination) -> Result<(), CodeGenError> { - Ok(match dest { - Destination::Memory(memory) => { - match memory.effective_address.base { - Base::Register(register) => self.arch.free(register)?, - Base::Label(_) => (), - } - if let Some(index) = memory.effective_address.index { - self.arch.free(index)?; - } - } - Destination::Register(register) => { - self.arch.free(register.register)?; - } - }) - } - - pub fn compile(&mut self, program: Vec) -> Result, CodeGenError> { - for stmt in program { - self.stmt(stmt)?; - } - - Ok(self.arch.finish()) - } - - fn function_scope_info(&self) -> Option { - self.scope_infos.iter().rev().find_map(|info| { - if let ScopeInfo::Function { label } = info { - Some(label.to_owned()) - } else { - None - } - }) - } - - fn loop_scope_info(&self) -> Option<(String, String)> { - self.scope_infos.iter().rev().find_map(|info| { - if let ScopeInfo::Loop { start, end } = info { - Some((start.to_owned(), end.to_owned())) - } else { - None - } - }) - } - - fn expr_dest(&mut self, expr: Expr) -> Result { - match expr { - Expr::Ident(expr) => Ok(self - .arch - .symbol_source(&expr.0, &self.scope)? - .try_into() - .unwrap()), - Expr::StructAccess(expr) => { - let type_ = expr.type_(&self.scope)?; - let (field_offset, mut field_size) = match expr.expr.type_(&self.scope)? { - Type::Custom(c) => { - match self.scope.find_type(&c).ok_or(TypeError::Nonexistent(c))? { - tt::Type::Struct(type_struct) => ( - type_struct.offset(&self.arch, &expr.field, &self.scope)?, - self.arch.size( - &type_struct.get_field_type(&expr.field).unwrap(), - &self.scope, - ), - ), - } - } - type_ => panic!("{type_:?}"), - }; - - if let Type::Array(_) = expr.type_(&self.scope)? { - field_size = self.arch.word_size(); - } - - let dest = match self.expr_dest(*expr.expr)? { - Destination::Memory(mut memory) => { - memory.effective_address.displacement = - if let Some(displacement) = memory.effective_address.displacement { - Some(&displacement + &field_offset) - } else { - Some(field_offset) - }; - memory.size = field_size; - - Destination::Memory(memory) - } - Destination::Register(register) => Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(register.register), - index: None, - scale: None, - displacement: Some(field_offset), - }, - size: field_size, - }), - }; - - //TODO: dis looks ugly, refactor it plz - if let Type::Array(_) = type_ { - let r = self.arch.alloc()?; - - self.arch - .lea(&r.dest(self.arch.word_size()), &dest.clone().into()); - self.free(dest)?; - - Ok(r.dest(self.arch.word_size())) - } else { - Ok(dest) - } - } - Expr::ArrayAccess(expr) => { - let type_ = expr.type_(&self.scope)?; - let base = self.expr_dest(*expr.expr)?; - let index = self.arch.alloc()?; - let r = self.arch.alloc()?; - let r_loc = r.dest(self.arch.word_size()); - - self.expr( - *expr.index.clone(), - Some(&index.dest(self.arch.word_size())), - None, - )?; - match &base { - Destination::Memory(memory) => { - self.arch.lea(&r_loc, &memory.effective_address); - } - Destination::Register(register) => { - self.arch - .mov(&Source::Register(register.to_owned()), &r_loc, false)?; - } - } - self.arch.array_offset( - &r_loc.clone().into(), - &index.source(self.arch.word_size()), - self.arch.size(&type_, &self.scope), - &r_loc, - )?; - - self.free(base)?; - self.arch.free(index)?; - - Ok(Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(r), - index: None, - scale: None, - displacement: None, - }, - size: self.arch.size(&type_, &self.scope), - })) - } - Expr::Unary(expr) if expr.op == UnOp::Deref => { - let type_ = expr.expr.type_(&self.scope)?; - let r = self.arch.alloc()?; - let dest = r.dest(self.arch.size(&type_, &self.scope)); - - self.expr(*expr.expr, Some(&dest), None)?; - - Ok(Destination::Memory(Memory { - effective_address: EffectiveAddress { - base: Base::Register(r), - index: None, - scale: None, - displacement: None, - }, - size: self.arch.size(&type_.inner()?, &self.scope), - })) - } - _ => unreachable!("Can't get address of rvalue"), - } - } -} diff --git a/src/codegen/error.rs b/src/codegen/error.rs deleted file mode 100644 index 1179565..0000000 --- a/src/codegen/error.rs +++ /dev/null @@ -1,24 +0,0 @@ -use crate::{ - archs::ArchError, - parser::{ExprError, OpParseError}, - register::allocator::AllocatorError, - symbol_table::SymbolTableError, - types::TypeError, -}; -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum CodeGenError { - #[error(transparent)] - OpParse(#[from] OpParseError), - #[error(transparent)] - Type(#[from] TypeError), - #[error(transparent)] - Allocator(#[from] AllocatorError), - #[error(transparent)] - SymbolTable(#[from] SymbolTableError), - #[error(transparent)] - Arch(#[from] ArchError), - #[error(transparent)] - Expr(#[from] ExprError), -} diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 6d52613..862def7 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -1,13 +1,18 @@ -mod argument; -mod codegen; -mod error; -pub mod operands; -mod sethi_ullman; - -pub use argument::Argument; -pub use codegen::CodeGen; -pub use error::CodeGenError; -pub use operands::{ - Base, Destination, EffectiveAddress, Immediate, Memory, Offset, Register, Source, -}; -pub use sethi_ullman::SethiUllman; +pub mod amd64_asm; + +use crate::{parser::OpParseError, Context}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum CodeGenError { + #[error(transparent)] + OpParse(#[from] OpParseError), +} + +pub trait Codegen<'a, 'ir> { + fn new(ctx: &'a Context<'ir>) -> Self + where + Self: Sized; + + fn compile(&mut self) -> Result, Box>; +} diff --git a/src/codegen/operands.rs b/src/codegen/operands.rs deleted file mode 100644 index 6111bb9..0000000 --- a/src/codegen/operands.rs +++ /dev/null @@ -1,233 +0,0 @@ -use crate::{ - archs::Arch, - parser::ExprLit, - register::{self, allocator::AllocatorError}, -}; - -#[derive(Debug, Clone, PartialEq, Default)] -pub struct Offset(pub isize); - -impl std::ops::Add for &Offset { - type Output = Offset; - - fn add(self, rhs: isize) -> Self::Output { - Offset(self.0 + rhs) - } -} - -impl std::ops::Sub for &Offset { - type Output = Offset; - - fn sub(self, rhs: isize) -> Self::Output { - Offset(self.0 - rhs) - } -} - -impl std::ops::Add<&Offset> for &Offset { - type Output = Offset; - - fn add(self, rhs: &Offset) -> Self::Output { - Offset(self.0 + rhs.0) - } -} - -impl std::ops::Sub<&Offset> for &Offset { - type Output = Offset; - - fn sub(self, rhs: &Offset) -> Self::Output { - Offset(self.0 - rhs.0) - } -} - -impl std::fmt::Display for Offset { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.0 > 0 { - write!(f, " + {}", self.0) - } else if self.0 < 0 { - write!(f, " - {}", self.0.abs()) - } else { - write!(f, "") - } - } -} - -#[derive(Clone, Debug, PartialEq)] -pub enum Base { - Register(register::Register), - Label(String), -} - -#[derive(Clone, Debug, PartialEq)] -pub struct Register { - pub register: register::Register, - pub size: usize, -} - -#[derive(Clone, Debug, PartialEq)] -pub struct EffectiveAddress { - pub base: Base, - pub index: Option, - pub scale: Option, - pub displacement: Option, -} - -#[derive(Debug, Clone)] -pub enum Immediate { - Int(i64), - UInt(u64), - Label(String), -} - -#[derive(Clone, Debug, PartialEq)] -pub struct Memory { - pub effective_address: EffectiveAddress, - pub size: usize, -} - -impl From for Immediate { - fn from(value: ExprLit) -> Self { - match value { - ExprLit::UInt(uint) => Self::UInt(uint.inner), - ExprLit::Int(int) => Self::Int(int.inner), - ExprLit::String(label) => Self::Label(label), - ExprLit::Bool(bool) => Self::UInt(if bool { 1 } else { 0 }), - ExprLit::Null => Self::UInt(0), - } - } -} - -#[derive(Debug, Clone)] -pub enum Source { - Memory(Memory), - Register(Register), - Immediate(Immediate), -} - -impl Source { - pub fn size(&self) -> Option { - match self { - Self::Memory(memory) => Some(memory.size), - Self::Register(register) => Some(register.size), - Self::Immediate(_) => None, - } - } - - pub fn free(self, arch: &mut Arch) -> Result<(), AllocatorError> { - match self { - Self::Register(Register { register, .. }) => arch.free(register), - Self::Memory(Memory { - effective_address: - EffectiveAddress { - base: Base::Register(register), - .. - }, - .. - }) => arch.free(register), - _ => Ok(()), - } - } -} - -impl TryInto for Source { - type Error = (); - - fn try_into(self) -> Result { - match self { - Self::Memory(memory) => Ok(Destination::Memory(memory)), - Self::Register(register) => Ok(Destination::Register(register)), - // FIXME: return an error - Self::Immediate(_) => unreachable!(), - } - } -} - -#[derive(Clone, Debug, PartialEq)] -pub enum Destination { - Memory(Memory), - Register(Register), -} - -impl Destination { - pub fn with_size(mut self, size: usize) -> Self { - match &mut self { - Self::Memory(memory) => { - memory.size = size; - } - Self::Register(register) => { - register.size = size; - } - } - - self - } - - pub fn size(&self) -> usize { - match self { - Self::Memory(memory) => memory.size, - Self::Register(register) => register.size, - } - } -} - -impl Into for Destination { - fn into(self) -> Source { - match self { - Self::Memory(memory) => Source::Memory(memory), - Self::Register(register) => Source::Register(register), - } - } -} - -impl Into for Destination { - fn into(self) -> EffectiveAddress { - match self { - Destination::Register(register) => EffectiveAddress { - base: Base::Register(register.register), - index: None, - scale: None, - displacement: None, - }, - Destination::Memory(memory) => memory.effective_address, - } - } -} - -impl Into for Source { - fn into(self) -> EffectiveAddress { - match self { - Source::Register(register) => EffectiveAddress { - base: Base::Register(register.register), - index: None, - scale: None, - displacement: None, - }, - Source::Memory(memory) => memory.effective_address, - Source::Immediate(_) => unreachable!(), - } - } -} - -impl Into for Destination { - fn into(self) -> Base { - match self { - Destination::Register(register) => Base::Register(register.register), - Destination::Memory(memory) => memory.effective_address.base, - } - } -} - -impl PartialEq<&Source> for &Destination { - fn eq(&self, other: &&Source) -> bool { - match (self, other) { - (Destination::Memory(lhs), Source::Memory(rhs)) if lhs == rhs => true, - (Destination::Register(lhs), Source::Register(rhs)) if lhs == rhs => true, - _ => false, - } - } -} - -impl PartialEq<&Destination> for &Source { - fn eq(&self, other: &&Destination) -> bool { - other == self - } -} diff --git a/src/compile.rs b/src/compile.rs index e691d14..14e16ae 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -1,10 +1,10 @@ use crate::{ - archs::{Amd64, Architecture}, - codegen::CodeGen, + codegen::{amd64_asm::Amd64Asm, Codegen}, lexer::Lexer, - parser, - passes::{MacroExpansion, Pass, SymbolResolver, TypeChecker}, + lowering::Lowering, + parser, Context, }; +use bumpalo::Bump; use clap::Parser; use std::{ fs::File, @@ -46,16 +46,19 @@ pub fn compile(args: CompileArgs) -> Result<(), Box> { file.read_to_string(&mut source_code)?; let lexer = Lexer::new(source_code); - let (mut stmts, mut scope) = parser::Parser::new(lexer)?.into_parts()?; + let ast = parser::Parser::new(lexer)?.parse()?; + let allocator = Bump::new(); + let mut ctx = Context::new(&allocator); - MacroExpansion::new(args.macro_libs).run_pass(&mut stmts, &mut scope); - SymbolResolver::new(()).run_pass(&mut stmts, &mut scope)?; - TypeChecker::new(()).run_pass(&mut stmts, &mut scope)?; + Lowering::new(&mut ctx).lower(ast); + ctx.ty_problem.solve(); - dbg!(&stmts); - dbg!(&scope); + //MacroExpansion::new(args.macro_libs).run_pass(&mut stmts, &mut scope); + //SymbolResolver::new(()).run_pass(&mut stmts, &mut scope)?; + //TypeChecker::new(()).run_pass(&mut stmts, &mut scope)?; - let code = CodeGen::new(Box::new(Amd64::new()), scope).compile(stmts)?; + let codegen: &mut dyn Codegen = &mut Amd64Asm::new(&ctx); + let code = codegen.compile()?; if args.assembly_only { let asm_filename = args.file.with_extension("s"); diff --git a/src/ir/mod.rs b/src/ir/mod.rs new file mode 100644 index 0000000..dff52ed --- /dev/null +++ b/src/ir/mod.rs @@ -0,0 +1,119 @@ +mod ordered_map; +mod types; + +use crate::parser::{BinOp, UnOp}; +use bumpalo::Bump; + +pub use types::{IntTy, Ty, TyArray, UintTy}; + +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash)] +pub struct Id { + pub global_id: usize, + pub node_id: usize, +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct Expr<'ir> { + pub ty: &'ir Ty<'ir>, + pub kind: ExprKind<'ir>, +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ExprKind<'ir> { + // It's a reference only because it doesn't work without indirection + Binary(BinOp, &'ir Expr<'ir>, &'ir Expr<'ir>), + Unary(UnOp, &'ir Expr<'ir>), + Ident(Id), + Lit(ExprLit<'ir>), +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum ExprLit<'ir> { + Int(i64), + UInt(u64), + Bool(bool), + String(&'ir str), + Null, +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum Stmt<'ir> { + Local(&'ir Variable<'ir>), + Item(Item<'ir>), + Expr(Expr<'ir>), + Return(Option>), +} + +#[derive(Debug, PartialEq)] +pub struct Block<'ir>(pub &'ir [Stmt<'ir>]); + +#[derive(Debug, PartialEq)] +pub struct Signature<'ir> { + pub params: &'ir [&'ir Ty<'ir>], + pub ret_ty: &'ir Ty<'ir>, +} + +#[derive(Debug, PartialEq)] +pub struct ItemFn<'ir> { + pub id: Id, + pub name: &'ir str, + pub signature: Signature<'ir>, + pub block: Block<'ir>, +} + +#[derive(Debug, PartialEq)] +pub struct Variable<'ir> { + pub id: Id, + pub name: &'ir str, + pub ty: &'ir Ty<'ir>, + pub initializer: Option>, +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum Item<'ir> { + Fn(&'ir ItemFn<'ir>), + Global(&'ir Variable<'ir>), +} + +#[derive(Debug, Clone, Copy)] +pub enum Node<'ir> { + Item(Item<'ir>), + Stmt(Stmt<'ir>), + Expr(Expr<'ir>), +} + +#[derive(Debug, Clone, Copy)] +pub struct Global<'ir>(pub &'ir [Node<'ir>]); + +#[derive(Debug)] +pub struct Ir<'ir> { + allocator: &'ir Bump, + globals: &'ir [Global<'ir>], +} + +impl<'ir> Ir<'ir> { + pub fn new(allocator: &'ir Bump) -> Self { + Ir { + allocator, + globals: &[], + } + } + + pub fn set_globals(&mut self, globals: &'ir [Global<'ir>]) { + self.globals = globals; + } + + pub fn iter_items(&self) -> impl Iterator> { + self.globals + .iter() + .map(|global| match global.0[0] { + Node::Item(item) => item, + _ => unreachable!(), + }) + .into_iter() + } + + pub fn get_node(&self, id: Id) -> &'ir Node<'ir> { + &self.globals[id.global_id].0[id.node_id] + } +} diff --git a/src/ir/ordered_map.rs b/src/ir/ordered_map.rs new file mode 100644 index 0000000..cc1debc --- /dev/null +++ b/src/ir/ordered_map.rs @@ -0,0 +1,11 @@ +pub trait OrderedMap { + fn get(&self, k: &K) -> Option<&V>; +} + +impl<'ir, K: Eq, V> OrderedMap for &'ir [(K, V)] { + fn get(&self, k: &K) -> Option<&V> { + self.iter() + .find(|(key, _)| key == k) + .map(|(_, value)| value) + } +} diff --git a/src/ir/types.rs b/src/ir/types.rs new file mode 100644 index 0000000..9e45799 --- /dev/null +++ b/src/ir/types.rs @@ -0,0 +1,166 @@ +use crate::ty_problem::Id; + +#[derive(Debug, PartialEq)] +pub struct TyArray<'ir> { + pub ty: &'ir Ty<'ir>, + pub len: usize, +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] +pub enum IntTy { + I8, + I16, + I32, + I64, + Isize, +} + +impl IntTy { + fn size(&self) -> Option { + Some(match self { + Self::I8 => 1, + Self::I16 => 2, + Self::I32 => 4, + Self::I64 => 8, + Self::Isize => return None, + }) + } +} + +impl std::fmt::Display for IntTy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::I8 => write!(f, "i8"), + Self::I16 => write!(f, "i16"), + Self::I32 => write!(f, "i32"), + Self::I64 => write!(f, "i64"), + Self::Isize => write!(f, "isize"), + } + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] +pub enum UintTy { + U8, + U16, + U32, + U64, + Usize, +} + +impl UintTy { + fn size(&self) -> Option { + Some(match self { + Self::U8 => 1, + Self::U16 => 2, + Self::U32 => 4, + Self::U64 => 8, + Self::Usize => return None, + }) + } + + pub fn to_signed(self) -> IntTy { + match self { + Self::U8 => IntTy::I8, + Self::U16 => IntTy::I16, + Self::U32 => IntTy::I32, + Self::U64 => IntTy::I64, + Self::Usize => IntTy::Isize, + } + } +} + +impl std::fmt::Display for UintTy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::U8 => write!(f, "u8"), + Self::U16 => write!(f, "u16"), + Self::U32 => write!(f, "u32"), + Self::U64 => write!(f, "u64"), + Self::Usize => write!(f, "usize"), + } + } +} + +#[derive(Debug, PartialEq)] +pub enum Ty<'ir> { + Void, + Null, + Bool, + Int(IntTy), + UInt(UintTy), + Ptr(&'ir Ty<'ir>), + Array(TyArray<'ir>), + Fn(&'ir [&'ir Ty<'ir>], &'ir Ty<'ir>), + Struct(&'ir [(&'ir str, &'ir Ty<'ir>)]), + Infer(Id), +} + +impl Ty<'_> { + pub fn size(&self, bitness: usize) -> usize { + match self { + Self::Void => 0, + Self::Null | Self::Bool => 1, + Self::Int(int) => match int { + IntTy::I8 => 1, + IntTy::I16 => 2, + IntTy::I32 => 4, + IntTy::I64 => 8, + IntTy::Isize => bitness / 8, + }, + Self::UInt(uint) => match uint { + UintTy::U8 => 1, + UintTy::U16 => 2, + UintTy::U32 => 4, + UintTy::U64 => 8, + UintTy::Usize => bitness / 8, + }, + Self::Ptr(_) | Self::Fn(_, _) => bitness / 8, + Self::Array(ty) => ty.ty.size(bitness) * ty.len, + Self::Struct(ty) => ty.iter().map(|(_, ty)| ty.size(bitness)).sum(), + Self::Infer(_) => unreachable!(), + } + } +} + +impl std::fmt::Display for Ty<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Int(int) => int.fmt(f), + Self::UInt(uint) => uint.fmt(f), + Self::Bool => write!(f, "bool"), + Self::Void => write!(f, "void"), + Self::Ptr(type_) => write!(f, "*{type_}"), + Self::Array(array) => write!(f, "{}[{}]", array.ty, array.len), + Self::Fn(params, return_type) => write!( + f, + "fn ({}) -> {return_type}", + params + .iter() + .map(|type_| type_.to_string()) + .collect::() + ), + Self::Null => write!(f, "NULL"), + Self::Struct(_) => write!(f, "owo"), + Self::Infer(id) => write!(f, "infer({id:?})"), + } + } +} + +impl<'ir> Ty<'ir> { + pub fn ptr(&self) -> bool { + matches!(self, Self::Ptr(..)) + } + + pub fn arr(&self) -> bool { + matches!(self, Self::Array(..)) + } + + pub fn signed(&self) -> bool { + matches!(self, Self::Int(..)) + } + + pub fn int(&self) -> bool { + matches!(self, Self::UInt(_) | Self::Int(_)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 8286058..fd440b2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,34 @@ -pub mod archs; pub mod codegen; pub mod compile; +pub mod ir; pub mod lexer; +pub mod lowering; pub mod macros; pub mod parser; pub mod passes; -pub mod register; -pub mod scope; -pub mod symbol_table; -pub mod type_table; -pub mod types; +pub mod ty_problem; + +use bumpalo::Bump; +use ir::{Ir, Ty}; +use ty_problem::TyProblem; + +#[derive(Debug)] +pub struct Context<'ir> { + pub allocator: &'ir Bump, + pub ir: Ir<'ir>, + pub ty_problem: TyProblem<'ir>, +} + +impl<'ir> Context<'ir> { + pub fn new(allocator: &'ir Bump) -> Self { + Self { + allocator, + ir: Ir::new(allocator), + ty_problem: TyProblem::new(), + } + } + + pub fn resolve_ty(&self, ty: &'ir Ty<'ir>) -> &'ir Ty<'ir> { + self.ty_problem.resolve_ty(self, ty) + } +} diff --git a/src/lowering/mod.rs b/src/lowering/mod.rs new file mode 100644 index 0000000..443351d --- /dev/null +++ b/src/lowering/mod.rs @@ -0,0 +1,406 @@ +mod scopes; + +use crate::{ + ir::{self, Id, Stmt}, + parser::{self, BinOp, Item, UnOp, Variable}, + ty_problem, Context, +}; +use scopes::Scopes; +use std::collections::HashMap; + +#[derive(Debug)] +pub struct Lowering<'a, 'ir> { + ctx: &'a mut Context<'ir>, + types: HashMap>, + scopes: Scopes<'ir>, + nodes: Vec>, + globals: Vec>, + nodes_map: HashMap>, + id: Id, + ret_ty: Option<&'ir ir::Ty<'ir>>, +} + +impl<'a, 'ir> Lowering<'a, 'ir> { + pub fn new(ctx: &'a mut Context<'ir>) -> Self { + Self { + ctx, + types: HashMap::new(), + scopes: Scopes::new(), + nodes: Vec::new(), + globals: Vec::new(), + nodes_map: HashMap::new(), + id: Id::default(), + ret_ty: None, + } + } + + pub fn lower(mut self, ast: Vec) { + self.scopes.enter(); + + ast.into_iter().for_each(|item| { + self.lower_item(item); + }); + + let globals = self.ctx.allocator.alloc_slice_copy(&self.globals); + self.ctx.ir.set_globals(globals); + } + + pub fn lower_item(&mut self, item: Item) -> Option> { + match item { + Item::Struct(item) => { + let mut fields = Vec::new(); + + for (field, ty) in item.fields { + fields.push((&*self.ctx.allocator.alloc_str(&field), self.lower_ty(ty))); + } + + let fields = self.ctx.allocator.alloc_slice_copy(&fields); + let ty = self.ctx.allocator.alloc(ir::Ty::Struct(fields)); + + self.types.insert(parser::Ty::Ident(item.name), ty); + + None + } + Item::Fn(item) => { + self.id.node_id = 1; + self.scopes.enter(); + + let ret_ty = self.lower_ty(item.ret_ty); + self.ret_ty = Some(ret_ty); + + let stmts = if let Some(block) = item.block { + block + .0 + .into_iter() + .map(|stmt| self.lower_stmt(stmt)) + .collect() + } else { + Vec::new() + }; + + self.ret_ty = None; + self.scopes.leave(); + + let params: Vec<&'ir ir::Ty<'ir>> = item + .params + .into_iter() + .map(|(_, ty)| self.lower_ty(ty)) + .collect(); + let signature = ir::Signature { + params: self.ctx.allocator.alloc_slice_copy(¶ms), + ret_ty, + }; + + self.nodes.insert( + 0, + ir::Node::Item(ir::Item::Fn(self.ctx.allocator.alloc(ir::ItemFn { + id: Id { + global_id: self.id.global_id, + node_id: 0, + }, + name: self.ctx.allocator.alloc_str(&item.name), + signature, + block: ir::Block(self.ctx.allocator.alloc_slice_copy(&stmts)), + }))), + ); + + self.globals + .push(ir::Global(self.ctx.allocator.alloc_slice_copy(&self.nodes))); + self.nodes.clear(); + self.id.global_id += 1; + self.id.node_id = 0; + + None + } + Item::Global(var) => { + let name = var.name.clone(); + let ir_var = self.lower_var_decl(var); + let node = ir::Node::Item(ir::Item::Global(ir_var)); + + self.scopes.insert_symbol(name, self.id); + self.nodes_map.insert(self.id, node); + + self.globals + .push(ir::Global(self.ctx.allocator.alloc_slice_copy(&[node]))); + self.id.global_id += 1; + + Some(ir::Item::Global(ir_var)) + } + } + } + + fn lower_var_decl(&mut self, variable: Variable) -> &'ir ir::Variable<'ir> { + let ty = self.lower_ty(variable.ty); + + let initializer = if let Some(expr) = variable.value { + let expr = self.lower_expr(expr); + let let_ty_var_id = self.tys_ty_var_id(ty); + let expr_ty_var_id = self.tys_ty_var_id(expr.ty); + + self.ctx.ty_problem.eq(let_ty_var_id, expr_ty_var_id); + + Some(expr) + } else { + None + }; + + let ir_variable = self.ctx.allocator.alloc(ir::Variable { + id: self.id, + name: self.ctx.allocator.alloc_str(&variable.name), + ty, + initializer, + }); + + ir_variable + } + + fn lower_stmt(&mut self, stmt: parser::Stmt) -> ir::Stmt<'ir> { + match stmt { + parser::Stmt::Local(var) => { + let name = var.name.clone(); + let ir_var = self.lower_var_decl(var); + let node = ir::Node::Stmt(ir::Stmt::Local(ir_var)); + + self.scopes.insert_symbol(name, self.id); + self.nodes_map.insert(self.id, node); + + self.nodes.push(node); + self.id.node_id += 1; + + ir::Stmt::Local(ir_var) + } + parser::Stmt::Item(item) => ir::Stmt::Item(self.lower_item(item).unwrap()), + parser::Stmt::Expr(expr) => ir::Stmt::Expr(self.lower_expr(expr)), + parser::Stmt::Return(stmt) => { + let expr = stmt.expr.map(|expr| { + let expr = self.lower_expr(expr); + let expr_ty_var_id = self.tys_ty_var_id(expr.ty); + let ret_ty_var = self.tys_ty_var_id(self.ret_ty.unwrap()); + + self.ctx.ty_problem.eq(expr_ty_var_id, ret_ty_var); + + expr + }); + + ir::Stmt::Return(expr) + } + _ => todo!(), + } + } + + fn lower_expr(&mut self, expr: parser::Expr) -> ir::Expr<'ir> { + match expr { + parser::Expr::Binary(parser::ExprBinary { + op, + ref left, + ref right, + }) => { + // TODO: remove clones + let lhs = self.lower_expr(*left.clone()); + let rhs = self.lower_expr(*right.clone()); + + let lhs_ty_var_id = self.tys_ty_var_id(lhs.ty); + let rhs_ty_var_id = self.tys_ty_var_id(rhs.ty); + + let ty = match op { + BinOp::Add => { + let ty = self.expr_ty(&expr); + let expr = self.tys_ty_var_id(ty); + + self.ctx + .ty_problem + .bin_add(expr, lhs_ty_var_id, rhs_ty_var_id); + + ty + } + BinOp::Sub => { + let ty = self.expr_ty(&expr); + let expr = self.tys_ty_var_id(ty); + + self.ctx + .ty_problem + .bin_sub(expr, lhs_ty_var_id, rhs_ty_var_id); + + ty + } + _ => { + self.ctx.ty_problem.eq(lhs_ty_var_id, rhs_ty_var_id); + + lhs.ty + } + }; + + ir::Expr { + ty, + kind: ir::ExprKind::Binary( + op, + self.ctx.allocator.alloc(lhs), + self.ctx.allocator.alloc(rhs), + ), + } + } + parser::Expr::Ident(parser::ExprIdent(ref ident)) => { + let id = self.scopes.get_symbol(ident).unwrap(); + let ty = self.expr_ty(&expr); + + ir::Expr { + ty, + kind: ir::ExprKind::Ident(id), + } + } + parser::Expr::Lit(ref lit) => { + let ty = self.expr_ty(&expr); + let kind = match lit { + parser::ExprLit::Int(lit) => ir::ExprKind::Lit(ir::ExprLit::Int(*lit)), + parser::ExprLit::UInt(lit) => ir::ExprKind::Lit(ir::ExprLit::UInt(*lit)), + parser::ExprLit::Bool(lit) => ir::ExprKind::Lit(ir::ExprLit::Bool(*lit)), + parser::ExprLit::String(lit) => { + ir::ExprKind::Lit(ir::ExprLit::String(self.ctx.allocator.alloc_str(&lit))) + } + parser::ExprLit::Null => ir::ExprKind::Lit(ir::ExprLit::Null), + }; + + ir::Expr { ty, kind } + } + parser::Expr::Unary(parser::ExprUnary { op, expr }) => { + let ir_expr = self.lower_expr(*expr); + let ty = match op { + UnOp::Address => self.ctx.allocator.alloc(ir::Ty::Ptr(ir_expr.ty)), + UnOp::Deref => { + let deref = &*self + .ctx + .allocator + .alloc(ir::Ty::Infer(self.ctx.ty_problem.new_infer_ty_var())); + let reference = self + .ctx + .ty_problem + .new_typed_ty_var(self.ctx.allocator.alloc(ir::Ty::Ptr(deref))); + let expr_ty_var_id = self.tys_ty_var_id(ir_expr.ty); + + self.ctx.ty_problem.eq(expr_ty_var_id, reference); + + deref + } + _ => ir_expr.ty, + }; + + ir::Expr { + ty, + kind: ir::ExprKind::Unary(op, self.ctx.allocator.alloc(ir_expr)), + } + } + _ => todo!(), + } + } + + fn lower_ty(&mut self, ty: parser::Ty) -> &'ir ir::Ty<'ir> { + match self.types.get(&ty) { + Some(ty) => *ty, + None => { + let ir_ty = match &ty { + parser::Ty::Null => self.ctx.allocator.alloc(ir::Ty::Null), + parser::Ty::Void => self.ctx.allocator.alloc(ir::Ty::Void), + parser::Ty::Bool => self.ctx.allocator.alloc(ir::Ty::Bool), + parser::Ty::Int(ty) => match ty { + parser::IntTy::I8 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I8)), + parser::IntTy::I16 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I16)), + parser::IntTy::I32 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I32)), + parser::IntTy::I64 => self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::I64)), + parser::IntTy::Isize => { + self.ctx.allocator.alloc(ir::Ty::Int(ir::IntTy::Isize)) + } + }, + parser::Ty::UInt(ty) => match ty { + parser::UintTy::U8 => { + self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U8)) + } + parser::UintTy::U16 => { + self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U16)) + } + parser::UintTy::U32 => { + self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U32)) + } + parser::UintTy::U64 => { + self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::U64)) + } + parser::UintTy::Usize => { + self.ctx.allocator.alloc(ir::Ty::UInt(ir::UintTy::Usize)) + } + }, + parser::Ty::Ptr(ref ty) => self + .ctx + .allocator + .alloc(ir::Ty::Ptr(self.lower_ty(*ty.clone()))), + parser::Ty::Array(parser::TyArray { ref ty, len }) => { + self.ctx.allocator.alloc(ir::Ty::Array(ir::TyArray { + len: *len, + ty: self.lower_ty(*ty.clone()), + })) + } + parser::Ty::Fn(ref params, ref ret_ty) => { + let mut alloced_params = Vec::new(); + + for ty in params { + alloced_params.push(self.lower_ty(ty.clone())); + } + + let params = &*self.ctx.allocator.alloc_slice_copy(&alloced_params); + + self.ctx + .allocator + .alloc(ir::Ty::Fn(params, self.lower_ty(*ret_ty.clone()))) + } + parser::Ty::Ident(ident) => { + return self.scopes.get_type(ident).unwrap(); + } + parser::Ty::Infer => self + .ctx + .allocator + .alloc(ir::Ty::Infer(self.ctx.ty_problem.new_infer_ty_var())), + }; + + if !matches!(ty, parser::Ty::Infer) { + self.types.insert(ty, ir_ty); + } + + ir_ty + } + } + } + + fn expr_ty(&mut self, expr: &parser::Expr) -> &'ir ir::Ty<'ir> { + match expr { + parser::Expr::Binary(_) => self + .ctx + .allocator + .alloc(ir::Ty::Infer(self.ctx.ty_problem.new_infer_ty_var())), + parser::Expr::Lit(lit) => match lit { + parser::ExprLit::Bool(_) => &ir::Ty::Bool, + parser::ExprLit::String(_) => &ir::Ty::Ptr(&ir::Ty::UInt(ir::UintTy::U8)), + _ => self + .ctx + .allocator + .alloc(ir::Ty::Infer(self.ctx.ty_problem.new_infer_ty_var())), + }, + parser::Expr::Ident(parser::ExprIdent(ident)) => { + let id = self.scopes.get_symbol(ident).unwrap(); + + match self.nodes_map.get(&id).unwrap() { + ir::Node::Stmt(stmt) => match stmt { + Stmt::Local(stmt) => stmt.ty, + _ => unreachable!(), + }, + _ => panic!("nono"), + } + } + _ => todo!(), + } + } + + fn tys_ty_var_id(&mut self, ty: &'ir ir::Ty<'ir>) -> ty_problem::Id { + match ty { + ir::Ty::Infer(id) => *id, + ty => self.ctx.ty_problem.new_typed_ty_var(ty), + } + } +} diff --git a/src/lowering/scopes.rs b/src/lowering/scopes.rs new file mode 100644 index 0000000..5670367 --- /dev/null +++ b/src/lowering/scopes.rs @@ -0,0 +1,67 @@ +use crate::ir::{Id, Ty}; +use std::collections::HashMap; + +#[derive(Debug)] +pub struct Scope<'ir> { + pub type_table: HashMap>, + pub symbol_table: HashMap, +} + +impl<'ir> Scope<'ir> { + pub fn new() -> Self { + Self { + type_table: HashMap::new(), + symbol_table: HashMap::new(), + } + } +} + +#[derive(Debug)] +pub struct Scopes<'ir>(Vec>); + +impl<'ir> Scopes<'ir> { + pub fn new() -> Self { + Self(Vec::new()) + } + + pub fn enter(&mut self) { + self.0.push(Scope::new()); + } + + pub fn leave(&mut self) { + self.0.pop(); + } + + pub fn find(&self, f: T) -> Option + where + T: Fn(&Scope<'ir>) -> Option, + { + for scope in &self.0 { + if let item @ Some(_) = f(scope) { + return item; + } + } + + None + } + + pub fn insert_type(&mut self, name: String, ty: &'ir Ty<'ir>) { + self.0.last_mut().unwrap().type_table.insert(name, ty); + } + + pub fn get_type(&self, name: &str) -> Option<&'ir Ty<'ir>> { + self.find(|scope| scope.type_table.get(name).map(|&ty| ty)) + } + + pub fn insert_symbol(&mut self, name: String, id: Id) { + self.0.last_mut().unwrap().symbol_table.insert(name, id); + } + + pub fn get_symbol(&self, name: &str) -> Option { + self.find(|scope| scope.symbol_table.get(name).map(|&id| id)) + } + + pub fn is_global(&self) -> bool { + self.0.len() < 1 + } +} diff --git a/src/parser/error.rs b/src/parser/error.rs index 431e167..d497be7 100644 --- a/src/parser/error.rs +++ b/src/parser/error.rs @@ -1,25 +1,15 @@ -use super::{ExprError, IntLitReprError, OpParseError}; -use crate::{ - lexer::{LexerError, Token}, - symbol_table::SymbolTableError, - types::{Type, TypeError}, -}; +use super::{OpParseError, Ty}; +use crate::lexer::{LexerError, Token}; use thiserror::Error; #[derive(Error, Debug)] pub enum ParserError { - #[error(transparent)] - Expr(#[from] ExprError), #[error(transparent)] Lexer(#[from] LexerError), #[error(transparent)] - Type(#[from] TypeError), + Type(#[from] TyError), #[error(transparent)] Operator(#[from] OpParseError), - #[error(transparent)] - Int(#[from] IntLitReprError), - #[error(transparent)] - SymbolTable(#[from] SymbolTableError), #[error("Expected token {0}, got {1}")] UnexpectedToken(Token, Token), #[error("Expected {0}")] @@ -30,19 +20,26 @@ pub enum ParserError { Prefix(Token), #[error("Failed to parse infix token {0}")] Infix(Token), - #[error("Call to undeclared function {0}")] - UndeclaredFunction(String), - #[error("Function has signature ({}), got called with ({})", - .0 - .iter() - .map(|type_| type_.to_string()) - .collect::>() - .join(", "), - .1 - .iter() - .map(|type_| type_.to_string()) - .collect::>() - .join(", ") - )] - FunctionArguments(Vec, Vec), +} + +#[derive(Error, Debug, PartialEq)] +pub enum TyError { + #[error("Operation between {0} and {1} are not allowed")] + Promotion(Ty, Ty), + #[error("Ident {0} not found")] + IdentNotFound(String), + #[error("Can't assign {0} to {1}")] + Assignment(Ty, Ty), + #[error("Can't cast {0} into {1}")] + Cast(Ty, Ty), + #[error("Expected return value of type {1}, got {0} instead")] + Return(Ty, Ty), + #[error("Variable can't be of type void")] + VoidVariable, + #[error("Type '{0}' doens't exits")] + Nonexistent(String), + #[error("Type {0} is not pointer")] + Deref(Ty), + #[error("Mismatched types expected {0}, found {1}")] + Mismatched(Ty, Ty), } diff --git a/src/parser/expr/error.rs b/src/parser/expr/error.rs deleted file mode 100644 index 7e9c3f0..0000000 --- a/src/parser/expr/error.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::{symbol_table::SymbolTableError, types::TypeError}; -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum ExprError { - #[error(transparent)] - Type(#[from] TypeError), - #[error(transparent)] - SymbolTable(#[from] SymbolTableError), -} diff --git a/src/parser/expr/expr.rs b/src/parser/expr/expr.rs index 18691ca..e86c5ea 100644 --- a/src/parser/expr/expr.rs +++ b/src/parser/expr/expr.rs @@ -1,16 +1,11 @@ -use super::{int_repr::UIntLitRepr, ExprError, IntLitRepr}; use crate::{ lexer::Token, - parser::op::{BinOp, UnOp}, - scope::Scope, - type_table, - types::{IntType, Type, TypeArray, TypeError}, + parser::{ + op::{BinOp, UnOp}, + Ty, + }, }; -pub trait Expression { - fn type_(&self, scope: &Scope) -> Result; -} - #[derive(Debug, Clone, PartialEq)] pub enum Expr { Binary(ExprBinary), @@ -27,60 +22,6 @@ pub enum Expr { MacroCall(MacroCall), } -impl Expression for Expr { - fn type_(&self, scope: &Scope) -> Result { - match self { - Self::Binary(expr) => expr.type_(scope), - Self::Unary(expr) => expr.type_(scope), - Self::Lit(literal) => literal.type_(scope), - Self::Ident(ident) => ident.type_(scope), - Self::Cast(cast) => cast.type_(scope), - Self::Struct(expr_struct) => expr_struct.type_(scope), - Self::Array(expr) => expr.type_(scope), - Self::StructAccess(expr_struct_access) => expr_struct_access.type_(scope), - Self::StructMethod(expr) => expr.type_(scope), - Self::ArrayAccess(expr) => expr.type_(scope), - Self::FunctionCall(expr) => expr.type_(scope), - Self::MacroCall(_) => unreachable!("Macro calls should've already been expanded"), - } - } -} - -impl Expr { - pub fn lvalue(&self) -> bool { - match self { - Self::StructAccess(_) - | Self::ArrayAccess(_) - | Self::Ident(_) - | Self::Unary(ExprUnary { - op: UnOp::Deref, .. - }) => true, - _ => false, - } - } - - pub fn int_lit_only(expr: &Expr) -> bool { - match expr { - Expr::Binary(expr) => { - Expr::int_lit_only(expr.left.as_ref()) && Expr::int_lit_only(expr.right.as_ref()) - } - Expr::Unary(expr) => Expr::int_lit_only(expr.expr.as_ref()), - Expr::Cast(expr) => { - if expr.type_.int() { - Expr::int_lit_only(expr.expr.as_ref()) - } else { - false - } - } - Expr::Lit(expr) => match expr { - ExprLit::Int(_) | ExprLit::UInt(_) => true, - _ => false, - }, - _ => false, - } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprBinary { pub op: BinOp, @@ -88,81 +29,18 @@ pub struct ExprBinary { pub right: Box, } -impl Expression for ExprBinary { - fn type_(&self, scope: &Scope) -> Result { - match &self.op { - BinOp::Sub => match ( - (&self.left, self.left.type_(scope)?), - (&self.right, self.right.type_(scope)?), - ) { - ((_, Type::Ptr(_)), (_, Type::Ptr(_))) => Ok(Type::Int(IntType::Isize)), - ((_, lhs), (_, rhs)) => Ok(Type::common_type(lhs, rhs)), - }, - BinOp::Add - | BinOp::Mul - | BinOp::Div - | BinOp::Assign - | BinOp::BitwiseAnd - | BinOp::BitwiseOr - | BinOp::Shl - | BinOp::Shr => Ok(Type::common_type( - self.left.type_(scope)?, - self.right.type_(scope)?, - )), - BinOp::LessThan - | BinOp::GreaterThan - | BinOp::LessEqual - | BinOp::GreaterEqual - | BinOp::Equal - | BinOp::NotEqual - | BinOp::LogicalAnd - | BinOp::LogicalOr => Ok(Type::Bool), - } - } -} - #[derive(Debug, Clone, PartialEq)] pub enum ExprLit { - Int(IntLitRepr), - UInt(UIntLitRepr), + Int(i64), + UInt(u64), Bool(bool), String(String), Null, } -impl Expression for ExprLit { - fn type_(&self, _: &Scope) -> Result { - match self { - ExprLit::Int(int) => Ok(int.type_()), - ExprLit::UInt(uint) => Ok(uint.type_()), - ExprLit::Bool(_) => Ok(Type::Bool), - ExprLit::String(_) => Ok(Type::Ptr(Box::new(Type::Int(IntType::I8)))), - ExprLit::Null => Ok(Type::Null), - } - } -} - -impl ExprLit { - pub fn signed(&self) -> bool { - match self { - ExprLit::Int(_) => true, - _ => false, - } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprIdent(pub String); -impl Expression for ExprIdent { - fn type_(&self, scope: &Scope) -> Result { - Ok(scope - .find_symbol(&self.0) - .ok_or(TypeError::IdentNotFound(self.0.to_owned()))? - .type_()) - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprStruct { pub name: String, @@ -172,39 +50,12 @@ pub struct ExprStruct { #[derive(Debug, Clone, PartialEq)] pub struct ExprArray(pub Vec); -impl Expression for ExprArray { - fn type_(&self, scope: &Scope) -> Result { - Ok(Type::Array(TypeArray { - type_: Box::new(self.0.get(0).unwrap().type_(scope)?), - length: self.0.len(), - })) - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprStructAccess { pub expr: Box, pub field: String, } -impl Expression for ExprStructAccess { - fn type_(&self, scope: &Scope) -> Result { - match self.expr.type_(scope)? { - Type::Custom(struct_name) => { - match scope - .find_type(&struct_name) - .ok_or(TypeError::Nonexistent(struct_name.to_owned()))? - { - type_table::Type::Struct(type_struct) => { - Ok(type_struct.get_field_type(&self.field).unwrap().to_owned()) - } - } - } - _ => unreachable!(), - } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprStructMethod { pub expr: Box, @@ -212,41 +63,12 @@ pub struct ExprStructMethod { pub arguments: Vec, } -impl Expression for ExprStructMethod { - fn type_(&self, scope: &Scope) -> Result { - match self.expr.type_(scope)? { - Type::Custom(struct_name) => { - match scope - .find_type(&struct_name) - .ok_or(TypeError::Nonexistent(struct_name.to_owned()))? - { - type_table::Type::Struct(type_struct) => Ok(type_struct - .find_method(&self.method) - .unwrap() - .return_type - .to_owned()), - } - } - _ => unreachable!(), - } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprFunctionCall { pub expr: Box, pub arguments: Vec, } -impl Expression for ExprFunctionCall { - fn type_(&self, scope: &Scope) -> Result { - match self.expr.type_(scope)? { - Type::Fn(_, type_) => Ok(*type_), - _ => unreachable!(), - } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct MacroCall { pub name: String, @@ -259,47 +81,14 @@ pub struct ExprArrayAccess { pub index: Box, } -impl Expression for ExprArrayAccess { - fn type_(&self, scope: &Scope) -> Result { - Ok(self.expr.type_(scope)?.inner()?) - } -} - -impl Expression for ExprStruct { - fn type_(&self, _: &Scope) -> Result { - Ok(Type::Custom(self.name.clone())) - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprUnary { pub op: UnOp, pub expr: Box, } -impl Expression for ExprUnary { - fn type_(&self, scope: &Scope) -> Result { - Ok(match &self.op { - UnOp::Negative => match self.expr.type_(scope)? { - Type::UInt(uint) => Type::Int(uint.to_signed()), - type_ => type_, - }, - UnOp::LogicalNot => Type::Bool, - UnOp::Address => Type::Ptr(Box::new(self.expr.type_(scope)?)), - UnOp::Deref => self.expr.type_(scope)?.inner()?, - UnOp::BitwiseNot => self.expr.type_(scope)?, - }) - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ExprCast { pub expr: Box, - pub type_: Type, -} - -impl Expression for ExprCast { - fn type_(&self, _: &Scope) -> Result { - Ok(self.type_.to_owned()) - } + pub ty: Ty, } diff --git a/src/parser/expr/int_repr.rs b/src/parser/expr/int_repr.rs deleted file mode 100644 index 8347da4..0000000 --- a/src/parser/expr/int_repr.rs +++ /dev/null @@ -1,113 +0,0 @@ -use crate::types::{IntType, Type, UintType}; -use std::num::ParseIntError; -use thiserror::Error; - -const I8_MIN: i64 = i8::MIN as i64; -const I8_MAX: i64 = i8::MAX as i64; -const I16_MIN: i64 = i16::MIN as i64; -const I16_MAX: i64 = i16::MAX as i64; -const I32_MIN: i64 = i32::MIN as i64; -const I32_MAX: i64 = i32::MAX as i64; -const I64_MIN: i64 = i64::MIN as i64; -const I64_MAX: i64 = i64::MAX as i64; - -const U8_MIN: u64 = u8::MIN as u64; -const U8_MAX: u64 = u8::MAX as u64; -const U16_MIN: u64 = u16::MIN as u64; -const U16_MAX: u64 = u16::MAX as u64; -const U32_MIN: u64 = u32::MIN as u64; -const U32_MAX: u64 = u32::MAX as u64; -const U64_MIN: u64 = u64::MIN as u64; -const U64_MAX: u64 = u64::MAX as u64; - -#[derive(Debug, Clone, PartialEq, Default)] -pub struct IntLitRepr { - pub inner: i64, -} - -impl IntLitRepr { - pub fn new(value: i64) -> Self { - Self { inner: value } - } - - pub fn type_(&self) -> Type { - match self.inner { - I8_MIN..=I8_MAX => Type::Int(IntType::I8), - I16_MIN..=I16_MAX => Type::Int(IntType::I16), - I32_MIN..=I32_MAX => Type::Int(IntType::I32), - I64_MIN..=I64_MAX => Type::Int(IntType::I64), - } - } - - pub fn zero_except_n_bytes(&mut self, n: usize) { - self.inner &= 0 << n * 8; - } - - pub fn negate(&mut self) { - self.inner = -self.inner; - } -} - -impl ToString for IntLitRepr { - fn to_string(&self) -> String { - self.inner.to_string() - } -} - -impl TryFrom for IntLitRepr { - type Error = IntLitReprError; - - fn try_from(value: UIntLitRepr) -> Result { - Ok(Self { - inner: value.inner.try_into().unwrap(), - }) - } -} - -#[derive(Debug, Clone, PartialEq, Default)] -pub struct UIntLitRepr { - pub inner: u64, -} - -impl UIntLitRepr { - pub fn new(value: u64) -> Self { - Self { inner: value } - } - - pub fn type_(&self) -> Type { - match self.inner { - U8_MIN..=U8_MAX => Type::UInt(UintType::U8), - U16_MIN..=U16_MAX => Type::UInt(UintType::U16), - U32_MIN..=U32_MAX => Type::UInt(UintType::U32), - U64_MIN..=U64_MAX => Type::UInt(UintType::U64), - } - } - - pub fn zero_except_n_bytes(&mut self, n: usize) { - self.inner &= 0 << n * 8; - } -} - -impl ToString for UIntLitRepr { - fn to_string(&self) -> String { - self.inner.to_string() - } -} - -impl TryFrom<&str> for UIntLitRepr { - type Error = IntLitReprError; - - fn try_from(value: &str) -> Result { - Ok(Self { - inner: value.parse::()?, - }) - } -} - -#[derive(Error, Debug)] -pub enum IntLitReprError { - #[error("{0} bits integers are not supported")] - TooLarge(usize), - #[error(transparent)] - ParseInt(#[from] ParseIntError), -} diff --git a/src/parser/expr/mod.rs b/src/parser/expr/mod.rs index e3f9e6b..a446347 100644 --- a/src/parser/expr/mod.rs +++ b/src/parser/expr/mod.rs @@ -1,10 +1,6 @@ -mod error; mod expr; -mod int_repr; -pub use error::ExprError; pub use expr::{ Expr, ExprArray, ExprArrayAccess, ExprBinary, ExprCast, ExprFunctionCall, ExprIdent, ExprLit, - ExprStruct, ExprStructAccess, ExprStructMethod, ExprUnary, Expression, MacroCall, + ExprStruct, ExprStructAccess, ExprStructMethod, ExprUnary, MacroCall, }; -pub use int_repr::{IntLitRepr, IntLitReprError, UIntLitRepr}; diff --git a/src/parser/item.rs b/src/parser/item.rs new file mode 100644 index 0000000..b126f00 --- /dev/null +++ b/src/parser/item.rs @@ -0,0 +1,22 @@ +use super::{Block, Ty, Variable}; + +#[derive(Debug, Clone, PartialEq)] +pub enum Item { + Global(Variable), + Fn(ItemFn), + Struct(ItemStruct), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ItemFn { + pub ret_ty: Ty, + pub name: String, + pub params: Vec<(String, Ty)>, + pub block: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ItemStruct { + pub name: String, + pub fields: Vec<(String, Ty)>, +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 0ac61b7..92fa918 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,13 +1,28 @@ mod error; -pub mod expr; +mod item; mod op; mod parser; mod precedence; mod stmt; +mod types; + +pub mod expr; -pub use error::ParserError; +pub use error::{ParserError, TyError}; pub use expr::*; +pub use item::{Item, ItemFn, ItemStruct}; pub use op::{BinOp, BitwiseOp, CmpOp, OpParseError, UnOp}; pub use parser::Parser; pub use precedence::Precedence; -pub use stmt::{Block, Stmt, StmtFor, StmtFunction, StmtIf, StmtReturn, StmtVarDecl, StmtWhile}; +pub use stmt::{Stmt, StmtFor, StmtIf, StmtReturn, StmtWhile}; +pub use types::{IntTy, Ty, TyArray, UintTy}; + +#[derive(Debug, Clone, PartialEq)] +pub struct Variable { + pub ty: Ty, + pub name: String, + pub value: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Block(pub Vec); diff --git a/src/parser/op.rs b/src/parser/op.rs index 4df0201..07b1cb4 100644 --- a/src/parser/op.rs +++ b/src/parser/op.rs @@ -13,7 +13,7 @@ pub enum OpParseError { Bitwise(BinOp), } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Copy)] pub enum BinOp { Add, Sub, @@ -61,7 +61,7 @@ impl TryFrom<&Token> for BinOp { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum UnOp { LogicalNot, Negative, diff --git a/src/parser/parser.rs b/src/parser/parser.rs index 24133d1..eb39b16 100644 --- a/src/parser/parser.rs +++ b/src/parser/parser.rs @@ -1,17 +1,15 @@ use super::{ expr::{ExprBinary, ExprLit, ExprUnary}, + item::Item, precedence::Precedence, stmt::{StmtFor, StmtIf, StmtReturn, StmtWhile}, BinOp, Block, Expr, ExprArray, ExprArrayAccess, ExprCast, ExprIdent, ExprStruct, - ExprStructMethod, MacroCall, ParserError, Stmt, StmtFunction, StmtVarDecl, UIntLitRepr, UnOp, + ExprStructMethod, IntTy, ItemFn, ItemStruct, MacroCall, ParserError, Stmt, Ty, TyArray, UintTy, + UnOp, Variable, }; use crate::{ lexer::{LexerError, Token}, parser::{ExprFunctionCall, ExprStructAccess}, - scope::{Scope, ScopeKind}, - symbol_table::{Symbol, SymbolFunction}, - type_table::{TypeStruct, TypeStructMethod}, - types::{IntType, Type, TypeArray, TypeError, UintType}, }; use std::collections::HashMap; @@ -22,23 +20,16 @@ pub struct Parser>> { lexer: T, cur_token: Option, peek_token: Option, - scope: Scope, - global_stms: Vec, prefix_fns: HashMap>, infix_fns: HashMap>, } impl>> Parser { pub fn new(mut lexer: T) -> Result { - let mut scope = Scope::new(); - scope.enter_new(ScopeKind::Global); - Ok(Self { cur_token: lexer.next().transpose()?, peek_token: lexer.next().transpose()?, lexer, - scope, - global_stms: Vec::new(), prefix_fns: HashMap::from([ (Token::Ident(Default::default()), Self::ident as PrefixFn), (Token::String(Default::default()), Self::string_lit), @@ -108,29 +99,20 @@ impl>> Parser { } } - pub fn into_parts(mut self) -> Result<(Vec, Scope), ParserError> { - Ok((self.parse()?, self.scope)) - } - - pub fn parse(&mut self) -> Result, ParserError> { - let mut stmts = Vec::new(); + pub fn parse(&mut self) -> Result, ParserError> { + let mut items = Vec::new(); while let Some(token) = &self.cur_token { - match token { + let item = match token { Token::Struct => self.parse_struct()?, - Token::Let => stmts.push(self.var_decl()?), - Token::Fn => { - if let Some(stmt) = self.function(true)? { - stmts.push(stmt) - } - } + Token::Let => self.global()?, + Token::Fn => self.function(true)?, _ => unreachable!(), - } + }; + items.push(item); } - stmts.extend_from_slice(&self.global_stms); - - Ok(stmts) + Ok(items) } pub fn expr(&mut self, precedence: Precedence) -> Result { @@ -163,7 +145,7 @@ impl>> Parser { left } - fn parse_struct(&mut self) -> Result<(), ParserError> { + fn parse_struct(&mut self) -> Result { self.expect(&Token::Struct)?; let name = match self @@ -180,41 +162,10 @@ impl>> Parser { self.expect(&Token::LBrace)?; let mut fields = Vec::new(); - let mut methods = Vec::new(); while !self.cur_token_is(&Token::RBrace) { if self.cur_token_is(&Token::Fn) { - self.expect(&Token::Fn)?; - - let method_name = match self.next_token()? { - Some(Token::Ident(ident)) => ident, - _ => todo!(), - }; - self.expect(&Token::LParen)?; - let mut params = self.params(Token::Comma, Token::RParen)?; - params.insert( - 0, - ( - "this".to_owned(), - Type::Ptr(Box::new(Type::Custom(name.clone()))), - ), - ); - self.expect(&Token::Arrow)?; - - let type_ = self.parse_type()?; - let block = self.compound_statement(ScopeKind::Function(type_.clone()))?; - - methods.push(TypeStructMethod { - return_type: type_.clone(), - name: method_name.clone(), - params: params.clone(), - }); - self.global_stms.push(Stmt::Function(StmtFunction { - return_type: type_, - name: format!("{name}__{method_name}"), - params, - block, - })); + // Handle struct methods here } else { let name = match self.next_token()? { Some(Token::Ident(ident)) => ident, @@ -237,65 +188,51 @@ impl>> Parser { self.expect(&Token::RBrace)?; - self.scope - .type_table_mut() - .define(crate::type_table::Type::Struct(TypeStruct { - name, - fields, - methods, - })); - - Ok(()) + Ok(Item::Struct(ItemStruct { name, fields })) } - fn stmt(&mut self) -> Result, ParserError> { + fn stmt(&mut self) -> Result { match self.cur_token.as_ref().unwrap() { - Token::Return => Ok(Some(self.parse_return()?)), - Token::If => Ok(Some(self.if_stmt()?)), - Token::While => Ok(Some(self.while_stmt()?)), - Token::For => Ok(Some(self.for_stmt()?)), - Token::Let => Ok(Some(self.var_decl()?)), + Token::Return => self.parse_return(), + Token::If => self.if_stmt(), + Token::While => self.while_stmt(), + Token::For => self.for_stmt(), + Token::Let => self.local(), Token::Continue => { self.expect(&Token::Continue)?; self.expect(&Token::Semicolon)?; - Ok(Some(Stmt::Continue)) + Ok(Stmt::Continue) } Token::Break => { self.expect(&Token::Break)?; self.expect(&Token::Semicolon)?; - Ok(Some(Stmt::Break)) + Ok(Stmt::Break) } - Token::Fn => self.function(true), + Token::Fn => Ok(Stmt::Item(self.function(true)?)), _ => { let expr = Stmt::Expr(self.expr(Precedence::default())?); self.expect(&Token::Semicolon)?; - Ok(Some(expr)) + Ok(expr) } } } - fn compound_statement(&mut self, scope_kind: ScopeKind) -> Result { + fn compound_statement(&mut self) -> Result { let mut stmts = Vec::new(); - self.scope.enter_new(scope_kind); self.expect(&Token::LBrace)?; while !self.cur_token_is(&Token::RBrace) { - if let Some(stmt) = self.stmt()? { - stmts.push(stmt); - } + stmts.push(self.stmt()?); } self.expect(&Token::RBrace)?; - Ok(Block { - statements: stmts, - scope: self.scope.leave(), - }) + Ok(Block(stmts)) } // This function is used only by macro expansion @@ -303,15 +240,13 @@ impl>> Parser { let mut stmts = Vec::new(); while self.cur_token.is_some() { - if let Some(stmt) = self.stmt()? { - stmts.push(stmt); - } + stmts.push(self.stmt()?); } Ok(stmts) } - fn parse_type(&mut self) -> Result { + fn parse_type(&mut self) -> Result { let mut n = 0; while self.cur_token_is(&Token::Asterisk) { self.expect(&Token::Asterisk)?; @@ -319,19 +254,19 @@ impl>> Parser { } let mut base = match self.next_token()?.unwrap() { - Token::U8 => Ok(Type::UInt(UintType::U8)), - Token::U16 => Ok(Type::UInt(UintType::U16)), - Token::U32 => Ok(Type::UInt(UintType::U32)), - Token::U64 => Ok(Type::UInt(UintType::U64)), - Token::I8 => Ok(Type::Int(IntType::I8)), - Token::I16 => Ok(Type::Int(IntType::I16)), - Token::I32 => Ok(Type::Int(IntType::I32)), - Token::I64 => Ok(Type::Int(IntType::I64)), - Token::Usize => Ok(Type::UInt(UintType::Usize)), - Token::Isize => Ok(Type::Int(IntType::Isize)), - Token::Bool => Ok(Type::Bool), - Token::Void => Ok(Type::Void), - Token::Ident(ident) => Ok(Type::Custom(ident)), + Token::U8 => Ok(Ty::UInt(UintTy::U8)), + Token::U16 => Ok(Ty::UInt(UintTy::U16)), + Token::U32 => Ok(Ty::UInt(UintTy::U32)), + Token::U64 => Ok(Ty::UInt(UintTy::U64)), + Token::I8 => Ok(Ty::Int(IntTy::I8)), + Token::I16 => Ok(Ty::Int(IntTy::I16)), + Token::I32 => Ok(Ty::Int(IntTy::I32)), + Token::I64 => Ok(Ty::Int(IntTy::I64)), + Token::Usize => Ok(Ty::UInt(UintTy::Usize)), + Token::Isize => Ok(Ty::Int(IntTy::Isize)), + Token::Bool => Ok(Ty::Bool), + Token::Void => Ok(Ty::Void), + Token::Ident(ident) => Ok(Ty::Ident(ident)), Token::Fn => { self.expect(&Token::LParen)?; @@ -348,13 +283,13 @@ impl>> Parser { self.expect(&Token::RParen)?; self.expect(&Token::Arrow)?; - Ok(Type::Fn(params, Box::new(self.parse_type()?))) + Ok(Ty::Fn(params, Box::new(self.parse_type()?))) } token => Err(ParserError::ParseType(token)), }?; while n > 0 { - base = Type::Ptr(Box::new(base)); + base = Ty::Ptr(Box::new(base)); n -= 1; } @@ -379,11 +314,11 @@ impl>> Parser { self.expect(&Token::If)?; let condition = self.expr(Precedence::default())?; - let consequence = self.compound_statement(ScopeKind::Local)?; + let consequence = self.compound_statement()?; let alternative = if self.cur_token_is(&Token::Else) { self.expect(&Token::Else)?; - Some(self.compound_statement(ScopeKind::Local)?) + Some(self.compound_statement()?) } else { None }; @@ -399,7 +334,7 @@ impl>> Parser { self.expect(&Token::While)?; let condition = self.expr(Precedence::default())?; - let block = self.compound_statement(ScopeKind::Loop)?; + let block = self.compound_statement()?; Ok(Stmt::While(StmtWhile { condition, block })) } @@ -411,7 +346,7 @@ impl>> Parser { None } else { let stmt = if self.cur_token_is(&Token::Let) { - self.var_decl()? + self.local()? } else { Stmt::Expr(self.expr(Precedence::default())?) }; @@ -432,7 +367,7 @@ impl>> Parser { Some(self.expr(Precedence::default())?) }; - let block = self.compound_statement(ScopeKind::Loop)?; + let block = self.compound_statement()?; Ok(Stmt::For(StmtFor { initializer: initializer.map(|initializer| Box::new(initializer)), @@ -442,7 +377,7 @@ impl>> Parser { })) } - fn array_type(&mut self, type_: &mut Type) -> Result<(), ParserError> { + fn array_type(&mut self, type_: &mut Ty) -> Result<(), ParserError> { if self.cur_token_is(&Token::LBracket) { self.expect(&Token::LBracket)?; @@ -451,9 +386,9 @@ impl>> Parser { let length: usize = str::parse(&int).unwrap(); self.expect(&Token::RBracket)?; - *type_ = Type::Array(TypeArray { - type_: Box::new(type_.clone()), - length, + *type_ = Ty::Array(TyArray { + ty: Box::new(type_.clone()), + len: length, }); } token => panic!("Expected integer, got {token}"), @@ -463,7 +398,7 @@ impl>> Parser { Ok(()) } - fn var_decl(&mut self) -> Result { + fn local(&mut self) -> Result { self.expect(&Token::Let)?; let name = match self.next_token()?.unwrap() { @@ -472,14 +407,16 @@ impl>> Parser { return Err(ParserError::ParseType(token)); } }; - self.expect(&Token::Colon)?; + let ty = if self.cur_token_is(&Token::Colon) { + self.expect(&Token::Colon)?; - let mut type_ = self.parse_type()?; - if let Type::Void = type_ { - return Err(ParserError::Type(TypeError::VoidVariable)); - } + let mut ty = self.parse_type()?; + self.array_type(&mut ty)?; - self.array_type(&mut type_)?; + ty + } else { + Ty::Infer + }; let expr = if self.cur_token_is(&Token::Assign) { self.expect(&Token::Assign)?; @@ -491,10 +428,45 @@ impl>> Parser { self.expect(&Token::Semicolon)?; - Ok(Stmt::VarDecl(StmtVarDecl::new(type_, name, expr))) + Ok(Stmt::Local(Variable { + name, + ty, + value: expr, + })) } - fn function(&mut self, func_definition: bool) -> Result, ParserError> { + fn global(&mut self) -> Result { + self.expect(&Token::Let)?; + + let name = match self.next_token()?.unwrap() { + Token::Ident(ident) => ident, + token => { + return Err(ParserError::ParseType(token)); + } + }; + self.expect(&Token::Colon)?; + + let mut ty = self.parse_type()?; + self.array_type(&mut ty)?; + + let expr = if self.cur_token_is(&Token::Assign) { + self.expect(&Token::Assign)?; + + Some(self.expr(Precedence::default())?) + } else { + None + }; + + self.expect(&Token::Semicolon)?; + + Ok(Item::Global(Variable { + name, + ty, + value: expr, + })) + } + + fn function(&mut self, func_definition: bool) -> Result { self.expect(&Token::Fn)?; let name = match self.next_token()?.unwrap() { @@ -511,7 +483,7 @@ impl>> Parser { let type_ = self.parse_type()?; let block = if self.cur_token_is(&Token::LBrace) { - Some(self.compound_statement(ScopeKind::Function(type_.clone()))?) + Some(self.compound_statement()?) } else { None }; @@ -520,28 +492,19 @@ impl>> Parser { panic!("Function definition is not supported here"); } - if let Some(block) = block { - Ok(Some(Stmt::Function(StmtFunction { - return_type: type_, - name, - params, - block, - }))) - } else { - self.scope.symbol_table_mut().push( - name.clone(), - Symbol::Function(SymbolFunction { - return_type: type_.clone(), - parameters: params.clone().into_iter().map(|(_, type_)| type_).collect(), - }), - )?; + if block.is_none() { self.expect(&Token::Semicolon)?; - - Ok(None) } + + Ok(Item::Fn(ItemFn { + ret_ty: type_, + name, + params, + block, + })) } - fn params(&mut self, delim: Token, end: Token) -> Result, ParserError> { + fn params(&mut self, delim: Token, end: Token) -> Result, ParserError> { let mut params = Vec::new(); while !self.cur_token_is(&end) { @@ -619,9 +582,7 @@ impl>> Parser { fn int_lit(&mut self) -> Result { match self.next_token()? { - Some(Token::Integer(num_str)) => Ok(Expr::Lit(ExprLit::UInt( - UIntLitRepr::try_from(&num_str[..]).map_err(|e| ParserError::Int(e))?, - ))), + Some(Token::Integer(num_str)) => Ok(Expr::Lit(ExprLit::UInt(num_str.parse().unwrap()))), Some(_) | None => unreachable!(), } } @@ -754,7 +715,7 @@ impl>> Parser { Ok(Expr::Cast(ExprCast { expr: Box::new(expr), - type_: self.parse_type()?, + ty: self.parse_type()?, })) } @@ -817,11 +778,9 @@ mod test { use crate::{ lexer::Lexer, parser::{ - BinOp, Expr, ExprBinary, ExprCast, ExprIdent, ExprLit, ExprUnary, ParserError, Stmt, - StmtVarDecl, UIntLitRepr, UnOp, + BinOp, Expr, ExprBinary, ExprCast, ExprIdent, ExprLit, ExprUnary, IntTy, ParserError, + Stmt, Ty, UintTy, UnOp, Variable, }, - scope::ScopeKind, - types::{IntType, Type, UintType}, }; #[test] @@ -837,18 +796,18 @@ mod test { op: BinOp::Add, left: Box::new(Expr::Binary(ExprBinary { op: BinOp::Mul, - left: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(1)))), - right: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(2)))), + left: Box::new(Expr::Lit(ExprLit::UInt(1))), + right: Box::new(Expr::Lit(ExprLit::UInt(2))), })), right: Box::new(Expr::Binary(ExprBinary { op: BinOp::Div, - left: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(3)))), + left: Box::new(Expr::Lit(ExprLit::UInt(3))), right: Box::new(Expr::Binary(ExprBinary { op: BinOp::Add, - left: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(4)))), + left: Box::new(Expr::Lit(ExprLit::UInt(4))), right: Box::new(Expr::Cast(ExprCast { - type_: Type::UInt(UintType::U8), - expr: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(1)))), + ty: Ty::UInt(UintTy::U8), + expr: Box::new(Expr::Lit(ExprLit::UInt(1))), })), })), })), @@ -862,24 +821,24 @@ mod test { } ", vec![ - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "foo".to_owned(), - None, - )), + Stmt::Local(Variable { + name: "foo".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), Stmt::Expr(Expr::Binary(ExprBinary { op: BinOp::Assign, left: Box::new(Expr::Ident(ExprIdent("foo".to_owned()))), right: Box::new(Expr::Binary(ExprBinary { op: BinOp::Add, left: Box::new(Expr::Cast(ExprCast { - type_: Type::UInt(UintType::U8), + ty: Ty::UInt(UintTy::U8), expr: Box::new(Expr::Unary(ExprUnary { op: UnOp::Negative, - expr: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(1)))), + expr: Box::new(Expr::Lit(ExprLit::UInt(1))), })), })), - right: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(5)))), + right: Box::new(Expr::Lit(ExprLit::UInt(5))), })), })), ], @@ -893,29 +852,29 @@ mod test { } ", vec![ - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "foo".to_owned(), - None, - )), - Stmt::VarDecl(StmtVarDecl::new( - Type::Int(IntType::I8), - "bar".to_owned(), - None, - )), + Stmt::Local(Variable { + name: "foo".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), + Stmt::Local(Variable { + name: "bar".to_owned(), + ty: Ty::Int(IntTy::I8), + value: None, + }), Stmt::Expr(Expr::Binary(ExprBinary { op: BinOp::Assign, left: Box::new(Expr::Ident(ExprIdent("bar".to_owned()))), right: Box::new(Expr::Binary(ExprBinary { op: BinOp::Add, left: Box::new(Expr::Cast(ExprCast { - type_: Type::Int(IntType::I8), + ty: Ty::Int(IntTy::I8), expr: Box::new(Expr::Ident(ExprIdent("foo".to_owned()))), })), right: Box::new(Expr::Binary(ExprBinary { op: BinOp::Div, - left: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(5)))), - right: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(10)))), + left: Box::new(Expr::Lit(ExprLit::UInt(5))), + right: Box::new(Expr::Lit(ExprLit::UInt(10))), })), })), })), @@ -930,13 +889,13 @@ mod test { vec![Stmt::Expr(Expr::Binary(ExprBinary { op: BinOp::Add, left: Box::new(Expr::Cast(ExprCast { - type_: Type::Int(IntType::I8), - expr: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(1)))), + ty: Ty::Int(IntTy::I8), + expr: Box::new(Expr::Lit(ExprLit::UInt(1))), })), right: Box::new(Expr::Binary(ExprBinary { op: BinOp::Div, - left: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(2)))), - right: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(3)))), + left: Box::new(Expr::Lit(ExprLit::UInt(2))), + right: Box::new(Expr::Lit(ExprLit::UInt(3))), })), }))], ), @@ -950,23 +909,23 @@ mod test { } ", vec![ - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "a".to_owned(), - None, - )), - Stmt::VarDecl(StmtVarDecl::new( - Type::UInt(UintType::U8), - "b".to_owned(), - None, - )), + Stmt::Local(Variable { + name: "a".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), + Stmt::Local(Variable { + name: "b".to_owned(), + ty: Ty::UInt(UintTy::U8), + value: None, + }), Stmt::Expr(Expr::Binary(ExprBinary { op: BinOp::Assign, left: Box::new(Expr::Ident(ExprIdent("a".to_owned()))), right: Box::new(Expr::Binary(ExprBinary { op: BinOp::Assign, left: Box::new(Expr::Ident(ExprIdent("b".to_owned()))), - right: Box::new(Expr::Lit(ExprLit::UInt(UIntLitRepr::new(69)))), + right: Box::new(Expr::Lit(ExprLit::UInt(69))), })), })), ], @@ -975,10 +934,10 @@ mod test { for (input, expected) in tests { let mut parser = Parser::new(Lexer::new(input.to_string())).unwrap(); - let ast = parser.compound_statement(ScopeKind::Global).unwrap(); + let ast = parser.compound_statement().unwrap(); assert_eq!( - &ast.statements, &expected, + &ast.0, &expected, "expected: {:?}, got: {:?}", expected, ast ); diff --git a/src/parser/stmt.rs b/src/parser/stmt.rs index 42de5d3..1b2e978 100644 --- a/src/parser/stmt.rs +++ b/src/parser/stmt.rs @@ -1,11 +1,10 @@ -use super::Expr; -use crate::{scope::ScopeImpl, types::Type}; +use super::{item::Item, Block, Expr, Variable}; #[derive(Debug, Clone, PartialEq)] pub enum Stmt { - VarDecl(StmtVarDecl), + Local(Variable), + Item(Item), Expr(Expr), - Function(StmtFunction), Return(StmtReturn), If(StmtIf), While(StmtWhile), @@ -14,25 +13,6 @@ pub enum Stmt { Break, } -#[derive(Debug, Clone, PartialEq)] -pub struct Block { - pub statements: Vec, - pub scope: ScopeImpl, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct StmtVarDecl { - pub type_: Type, - pub name: String, - pub value: Option, -} - -impl StmtVarDecl { - pub fn new(type_: Type, name: String, value: Option) -> Self { - Self { type_, name, value } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct StmtReturn { pub expr: Option, @@ -58,11 +38,3 @@ pub struct StmtFor { pub increment: Option, pub block: Block, } - -#[derive(Debug, Clone, PartialEq)] -pub struct StmtFunction { - pub return_type: Type, - pub name: String, - pub params: Vec<(String, Type)>, - pub block: Block, -} diff --git a/src/types/types.rs b/src/parser/types.rs similarity index 62% rename from src/types/types.rs rename to src/parser/types.rs index 0529035..8dabb45 100644 --- a/src/types/types.rs +++ b/src/parser/types.rs @@ -1,13 +1,13 @@ -use super::TypeError; +use super::error::TyError; #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] -pub struct TypeArray { - pub type_: Box, - pub length: usize, +pub struct TyArray { + pub ty: Box, + pub len: usize, } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] -pub enum IntType { +pub enum IntTy { I8, I16, I32, @@ -15,7 +15,7 @@ pub enum IntType { Isize, } -impl IntType { +impl IntTy { fn size(&self) -> Option { Some(match self { Self::I8 => 1, @@ -27,7 +27,7 @@ impl IntType { } } -impl std::fmt::Display for IntType { +impl std::fmt::Display for IntTy { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::I8 => write!(f, "i8"), @@ -40,7 +40,7 @@ impl std::fmt::Display for IntType { } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] -pub enum UintType { +pub enum UintTy { U8, U16, U32, @@ -48,7 +48,7 @@ pub enum UintType { Usize, } -impl UintType { +impl UintTy { fn size(&self) -> Option { Some(match self { Self::U8 => 1, @@ -59,18 +59,18 @@ impl UintType { }) } - pub fn to_signed(self) -> IntType { + pub fn to_signed(self) -> IntTy { match self { - Self::U8 => IntType::I8, - Self::U16 => IntType::I16, - Self::U32 => IntType::I32, - Self::U64 => IntType::I64, - Self::Usize => IntType::Isize, + Self::U8 => IntTy::I8, + Self::U16 => IntTy::I16, + Self::U32 => IntTy::I32, + Self::U64 => IntTy::I64, + Self::Usize => IntTy::Isize, } } } -impl std::fmt::Display for UintType { +impl std::fmt::Display for UintTy { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::U8 => write!(f, "u8"), @@ -83,19 +83,20 @@ impl std::fmt::Display for UintType { } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] -pub enum Type { - Int(IntType), - UInt(UintType), - Bool, - Void, - Custom(String), - Ptr(Box), - Array(TypeArray), - Fn(Vec, Box), +pub enum Ty { Null, + Void, + Bool, + Int(IntTy), + UInt(UintTy), + Ident(String), + Ptr(Box), + Array(TyArray), + Fn(Vec, Box), + Infer, } -impl std::fmt::Display for Type { +impl std::fmt::Display for Ty { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Int(int) => int.fmt(f), @@ -103,8 +104,8 @@ impl std::fmt::Display for Type { Self::Bool => write!(f, "bool"), Self::Void => write!(f, "void"), Self::Ptr(type_) => write!(f, "*{type_}"), - Self::Custom(name) => write!(f, "{name}"), - Self::Array(array) => write!(f, "{}[{}]", array.type_, array.length), + Self::Ident(name) => write!(f, "{name}"), + Self::Array(array) => write!(f, "{}[{}]", array.ty, array.len), Self::Fn(params, return_type) => write!( f, "fn ({}) -> {return_type}", @@ -114,11 +115,12 @@ impl std::fmt::Display for Type { .collect::() ), Self::Null => write!(f, "NULL"), + Self::Infer => unreachable!(), } } } -impl Type { +impl Ty { pub fn ptr(&self) -> bool { matches!(self, Self::Ptr(..)) } @@ -132,73 +134,71 @@ impl Type { } pub fn int(&self) -> bool { - matches!(self, Type::UInt(_) | Type::Int(_)) + matches!(self, Ty::UInt(_) | Ty::Int(_)) } - pub fn cast(from: Self, to: Self) -> Result { + pub fn cast(from: Self, to: Self) -> Result { match (from, to) { (from, to) if from.int() && to.int() => Ok(to), - (from, to) if from == Self::Bool && to.int() || from.int() && to == Type::Bool => { - Ok(to) - } + (from, to) if from == Self::Bool && to.int() || from.int() && to == Ty::Bool => Ok(to), (from, to) if from.arr() && to.ptr() && from.inner().unwrap() == to.inner().unwrap() => { Ok(to) } - (Type::Array(_), Type::Ptr(pointee)) if pointee.as_ref() == &Type::Void => { - Ok(Type::Ptr(pointee)) + (Ty::Array(_), Ty::Ptr(pointee)) if pointee.as_ref() == &Ty::Void => { + Ok(Ty::Ptr(pointee)) } (from, to) if from.ptr() && to.ptr() => Ok(to), (from, to) if from.ptr() && to.int() => Ok(to), - (from, to) => Err(TypeError::Cast(from, to)), + (from, to) => Err(TyError::Cast(from, to)), } } pub fn size(&self) -> Option { match self { - Type::Void => Some(0), - Type::Bool => Some(1), - Type::Int(int) => int.size(), - Type::UInt(uint) => uint.size(), + Ty::Void => Some(0), + Ty::Bool => Some(1), + Ty::Int(int) => int.size(), + Ty::UInt(uint) => uint.size(), _ => None, } } - pub fn inner(&self) -> Result { + pub fn inner(&self) -> Result { match self { Self::Ptr(type_) => Ok(type_.as_ref().to_owned()), - Self::Array(array) => Ok(*array.type_.clone()), - type_ => Err(TypeError::Deref(type_.clone())), + Self::Array(array) => Ok(*array.ty.clone()), + type_ => Err(TyError::Deref(type_.clone())), } } - pub fn common_type(lhs: Type, rhs: Type) -> Type { + pub fn common_type(lhs: Ty, rhs: Ty) -> Ty { match (lhs, rhs) { (lhs, rhs) if lhs == rhs => lhs, - (type_ @ Type::Ptr(_), int) | (int, type_ @ Type::Ptr(_)) if int.int() => type_, - (type_ @ Type::Ptr(_), Type::Null) | (Type::Null, type_ @ Type::Ptr(_)) => type_, - (Type::UInt(lhs), Type::UInt(rhs)) => { + (type_ @ Ty::Ptr(_), int) | (int, type_ @ Ty::Ptr(_)) if int.int() => type_, + (type_ @ Ty::Ptr(_), Ty::Null) | (Ty::Null, type_ @ Ty::Ptr(_)) => type_, + (Ty::UInt(lhs), Ty::UInt(rhs)) => { if lhs > rhs { - Type::UInt(lhs) + Ty::UInt(lhs) } else { - Type::UInt(rhs) + Ty::UInt(rhs) } } - (Type::Int(lhs), Type::Int(rhs)) => { + (Ty::Int(lhs), Ty::Int(rhs)) => { if lhs > rhs { - Type::Int(lhs) + Ty::Int(lhs) } else { - Type::Int(rhs) + Ty::Int(rhs) } } - (Type::UInt(uint), Type::Int(int)) | (Type::Int(int), Type::UInt(uint)) => { + (Ty::UInt(uint), Ty::Int(int)) | (Ty::Int(int), Ty::UInt(uint)) => { let uint_int = uint.to_signed(); if uint_int <= int { - Type::Int(int) + Ty::Int(int) } else { - Type::Int(uint_int) + Ty::Int(uint_int) } } (lhs, rhs) => unreachable!("Failed to get common type for {lhs} and {rhs}"), diff --git a/src/passes/mod.rs b/src/passes/mod.rs index b10778f..4cdfa03 100644 --- a/src/passes/mod.rs +++ b/src/passes/mod.rs @@ -1,9 +1,9 @@ -mod macro_expansion; -mod pass; -mod symbol_resolver; -mod type_checker; - -pub use macro_expansion::MacroExpansion; -pub use pass::Pass; -pub use symbol_resolver::SymbolResolver; -pub use type_checker::TypeChecker; +//mod macro_expansion; +//mod pass; +//mod symbol_resolver; +//mod type_checker; +// +//pub use macro_expansion::MacroExpansion; +//pub use pass::Pass; +//pub use symbol_resolver::SymbolResolver; +//pub use type_checker::TypeChecker; diff --git a/src/register/allocator/error.rs b/src/register/allocator/error.rs deleted file mode 100644 index 0ad21d0..0000000 --- a/src/register/allocator/error.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::register::Register; -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum AllocatorError { - #[error("Register was double freed")] - DoubleFree, - #[error("Ran out of registers, whoops!")] - RanOutOfRegisters, - #[error("Register {} is already in use", .0.from_size(8))] - AlreadyInUse(Register), -} diff --git a/src/register/allocator/mod.rs b/src/register/allocator/mod.rs deleted file mode 100644 index e2c6944..0000000 --- a/src/register/allocator/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod allocator; -mod error; - -pub use allocator::RegisterAllocator; -pub use error::AllocatorError; diff --git a/src/register/mod.rs b/src/register/mod.rs deleted file mode 100644 index e4f2dd2..0000000 --- a/src/register/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod allocator; -mod register; - -pub use register::Register; diff --git a/src/register/register.rs b/src/register/register.rs deleted file mode 100644 index d3dc7ca..0000000 --- a/src/register/register.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate::codegen::{ - operands::{self, Source}, - Destination, -}; - -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct Register { - byte: &'static str, - word: &'static str, - dword: &'static str, - qword: &'static str, -} - -impl Register { - pub const fn new( - byte: &'static str, - word: &'static str, - dword: &'static str, - qword: &'static str, - ) -> Self { - return Self { - byte, - word, - dword, - qword, - }; - } - - pub fn byte(&self) -> &'static str { - self.byte - } - - pub fn word(&self) -> &'static str { - self.word - } - - pub fn dword(&self) -> &'static str { - self.dword - } - - pub fn qword(&self) -> &'static str { - self.qword - } - - pub fn from_size(&self, size: usize) -> &'static str { - match size { - 1 => self.byte(), - 2 => self.word(), - 4 => self.dword(), - 8 => self.qword(), - _ => unreachable!(), - } - } - - pub fn dest(&self, size: usize) -> Destination { - Destination::Register(operands::Register { - register: *self, - size, - }) - } - - pub fn source(&self, size: usize) -> Source { - Source::Register(operands::Register { - register: *self, - size, - }) - } -} diff --git a/src/scope.rs b/src/scope.rs deleted file mode 100644 index 6abdf8c..0000000 --- a/src/scope.rs +++ /dev/null @@ -1,121 +0,0 @@ -use crate::{ - symbol_table::{Symbol, SymbolTable}, - type_table::{self as tt, TypeTable}, - types::Type, -}; - -#[derive(Debug, Clone, PartialEq)] -pub enum ScopeKind { - Global, - Local, - Loop, - Function(Type), -} - -#[derive(Debug, Clone, PartialEq)] -pub struct ScopeImpl { - pub type_table: TypeTable, - pub symbol_table: SymbolTable, - pub kind: ScopeKind, -} - -impl ScopeImpl { - pub fn new(kind: ScopeKind) -> Self { - Self { - type_table: TypeTable::new(), - symbol_table: SymbolTable::new(), - kind, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Scope(Vec); - -impl Scope { - pub fn new() -> Self { - Self(Vec::new()) - } - - pub fn enter_new(&mut self, kind: ScopeKind) { - self.0.push(ScopeImpl::new(kind)); - } - - pub fn enter(&mut self, scope_impl: ScopeImpl) { - self.0.push(scope_impl); - } - - pub fn leave(&mut self) -> ScopeImpl { - if self.0.len() > 1 { - self.0.pop().unwrap() - } else { - panic!("Can't leave outermost scope"); - } - } - - pub fn find_symbol(&self, name: &str) -> Option<&Symbol> { - for scope in &self.0 { - let symbol = scope.symbol_table.find(name); - - if symbol.is_some() { - return symbol; - } - } - - return None; - } - - pub fn find_symbol_mut(&mut self, name: &str) -> Option<&mut Symbol> { - for scope in &mut self.0 { - let symbol = scope.symbol_table.find_mut(name); - - if symbol.is_some() { - return symbol; - } - } - - return None; - } - - pub fn find_type(&self, name: &str) -> Option<&tt::Type> { - for scope in &self.0 { - let type_ = scope.type_table.find(name); - - if type_.is_some() { - return type_; - } - } - - return None; - } - - pub fn symbol_table_mut(&mut self) -> &mut SymbolTable { - &mut self.0.last_mut().unwrap().symbol_table - } - - pub fn type_table(&self) -> &TypeTable { - &self.0.last().unwrap().type_table - } - - pub fn type_table_mut(&mut self) -> &mut TypeTable { - &mut self.0.last_mut().unwrap().type_table - } - - pub fn kind(&self) -> &ScopeKind { - &self.0.last().unwrap().kind - } - - pub fn local(&self) -> bool { - self.0.len() > 1 - } - - pub fn return_type(&self) -> Option<&Type> { - self.0.iter().rev().find_map(|scope_impl| { - if let ScopeKind::Function(type_) = &scope_impl.kind { - return Some(type_); - } else { - None - } - }) - } -} diff --git a/src/symbol_table/error.rs b/src/symbol_table/error.rs deleted file mode 100644 index 7bf4e9b..0000000 --- a/src/symbol_table/error.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::types::TypeError; -use thiserror::Error; - -#[derive(Error, Debug, PartialEq)] -pub enum SymbolTableError { - #[error("Redeclaration of '{0}'")] - Redeclaration(String), - #[error("Symbol '{0}' not found")] - NotFound(String), - #[error(transparent)] - Type(TypeError), -} diff --git a/src/symbol_table/mod.rs b/src/symbol_table/mod.rs deleted file mode 100644 index 4ed64dd..0000000 --- a/src/symbol_table/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod error; -mod symbol_table; - -pub use error::SymbolTableError; -pub use symbol_table::{ - Symbol, SymbolFunction, SymbolGlobal, SymbolLocal, SymbolParam, SymbolTable, -}; diff --git a/src/symbol_table/symbol_table.rs b/src/symbol_table/symbol_table.rs deleted file mode 100644 index ea2b0dc..0000000 --- a/src/symbol_table/symbol_table.rs +++ /dev/null @@ -1,91 +0,0 @@ -use super::SymbolTableError; -use crate::{codegen::Offset, types::Type}; -use std::collections::HashMap; - -#[derive(Debug, Clone, PartialEq)] -pub enum Symbol { - Global(SymbolGlobal), - Local(SymbolLocal), - Param(SymbolParam), - Function(SymbolFunction), -} - -impl Symbol { - pub fn type_(&self) -> Type { - match self { - Self::Global(global) => global.type_.clone(), - Self::Local(local) => local.type_.clone(), - Self::Param(param) => param.type_.clone(), - Self::Function(func) => { - Type::Fn(func.parameters.clone(), Box::new(func.return_type.clone())) - } - } - } - - pub fn function_unchecked(&self) -> &SymbolFunction { - match self { - Symbol::Function(symbol) => symbol, - _ => unreachable!(), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct SymbolGlobal { - pub label: String, - pub type_: Type, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct SymbolLocal { - pub offset: Offset, - pub type_: Type, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct SymbolParam { - pub preceding: Vec, - pub type_: Type, - pub offset: Offset, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct SymbolFunction { - pub return_type: Type, - pub parameters: Vec, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct SymbolTable(HashMap); - -impl SymbolTable { - const MAX_SYMBOLS: usize = 512; - - pub fn new() -> Self { - Self(HashMap::new()) - } - - pub fn find(&self, name: &str) -> Option<&Symbol> { - self.0.get(name) - } - - pub fn find_mut(&mut self, name: &str) -> Option<&mut Symbol> { - self.0.get_mut(name) - } - - pub fn iter_mut(&mut self) -> impl Iterator { - self.0.values_mut().into_iter() - } - - pub fn push(&mut self, name: String, symbol: Symbol) -> Result<(), SymbolTableError> { - assert!(self.0.len() < Self::MAX_SYMBOLS); - - if self.0.contains_key(&name) { - Err(SymbolTableError::Redeclaration(name)) - } else { - self.0.insert(name, symbol); - - Ok(()) - } - } -} diff --git a/src/ty_problem/mod.rs b/src/ty_problem/mod.rs new file mode 100644 index 0000000..7d72021 --- /dev/null +++ b/src/ty_problem/mod.rs @@ -0,0 +1,208 @@ +use crate::{ + ir::{self, Ty}, + Context, +}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Id(usize); + +#[derive(Debug, Clone)] +pub enum TyVar<'ir> { + Typed(&'ir Ty<'ir>), + Infer(Id), +} + +impl<'ir> TyVar<'ir> { + pub fn ty(&self) -> Option<&'ir Ty<'ir>> { + match self { + Self::Typed(ty) => Some(ty), + Self::Infer(_) => None, + } + } +} + +impl<'ir> From<&'ir Ty<'ir>> for TyVar<'ir> { + fn from(value: &'ir Ty<'ir>) -> Self { + match value { + Ty::Infer(id) => Self::Infer(*id), + ty => Self::Typed(ty), + } + } +} + +#[derive(Debug)] +enum Constraint { + Eq(Id, Id), + BinAdd { expr: Id, lhs: Id, rhs: Id }, + BinSub { expr: Id, lhs: Id, rhs: Id }, +} + +#[derive(Debug)] +pub struct TyProblem<'ir> { + ty_vars: Vec>, + constraints: Vec, +} + +impl<'ir> TyProblem<'ir> { + pub fn new() -> Self { + Self { + ty_vars: Vec::new(), + constraints: Vec::new(), + } + } + + fn new_ty_var(&mut self, ty_var: TyVar<'ir>) -> Id { + let i = self.ty_vars.len(); + + self.ty_vars.push(ty_var); + + Id(i) + } + + pub fn new_infer_ty_var(&mut self) -> Id { + let i = self.ty_vars.len(); + + self.new_ty_var(TyVar::Infer(Id(i))) + } + + pub fn new_typed_ty_var(&mut self, ty: &'ir Ty<'ir>) -> Id { + self.new_ty_var(TyVar::Typed(ty)) + } + + pub fn get_ty_var(&self, id: Id) -> &TyVar<'ir> { + &self.ty_vars[id.0] + } + + pub fn get_ty_var_mut(&mut self, id: Id) -> &mut TyVar<'ir> { + &mut self.ty_vars[id.0] + } + + pub fn eq(&mut self, lhs: Id, rhs: Id) { + self.constraints.push(Constraint::Eq(lhs, rhs)); + } + + pub fn bin_add(&mut self, expr: Id, lhs: Id, rhs: Id) { + self.constraints.push(Constraint::BinAdd { expr, lhs, rhs }); + } + + pub fn bin_sub(&mut self, expr: Id, lhs: Id, rhs: Id) { + self.constraints.push(Constraint::BinSub { expr, lhs, rhs }); + } + + fn unify(&mut self, lhs: TyVar<'ir>, rhs: TyVar<'ir>) -> bool { + match (lhs, rhs) { + (TyVar::Infer(lhs), TyVar::Infer(rhs)) => { + self.eq(lhs, rhs); + + false + } + (TyVar::Typed(ty), TyVar::Infer(id)) | (TyVar::Infer(id), TyVar::Typed(ty)) => { + *self.get_ty_var_mut(id) = TyVar::Typed(ty); + + true + } + (TyVar::Typed(lhs), TyVar::Typed(rhs)) => match (lhs, rhs) { + (Ty::Ptr(lhs), Ty::Ptr(rhs)) => self.unify((*lhs).into(), (*rhs).into()), + _ => { + assert_eq!(lhs, rhs, "Failed to unify {lhs} and {rhs}"); + + false + } + }, + } + } + + fn apply_constraints(&mut self) -> bool { + let mut constraints = std::mem::take(&mut self.constraints); + let mut progress = false; + + constraints.retain(|constraint| match constraint { + Constraint::Eq(lhs, rhs) => { + progress |= + self.unify(self.get_ty_var(*lhs).clone(), self.get_ty_var(*rhs).clone()); + + false + } + Constraint::BinAdd { expr, lhs, rhs } => { + let (lhs, rhs) = if let Some(Ty::Ptr(_)) = self.get_ty_var(*rhs).ty() { + (rhs, lhs) + } else { + (lhs, rhs) + }; + + if let Some(ty) = self.get_ty_var(*lhs).ty() { + match ty { + Ty::Ptr(_) => { + *self.get_ty_var_mut(*rhs) = TyVar::Typed(&Ty::Int(ir::IntTy::Isize)); + *self.get_ty_var_mut(*expr) = TyVar::Typed(&Ty::Int(ir::IntTy::Isize)); + progress |= true; + } + Ty::Int(_) | Ty::UInt(_) => { + *self.get_ty_var_mut(*rhs) = TyVar::Typed(ty); + *self.get_ty_var_mut(*expr) = TyVar::Typed(ty); + progress |= true; + } + _ => unreachable!("Bad type, expected integer or pointer, got {}", ty), + }; + + false + } else { + self.eq(*expr, *lhs); + + true + } + } + Constraint::BinSub { expr, lhs, rhs } => { + let (lhs, rhs) = if let Some(Ty::Ptr(_)) = self.get_ty_var(*rhs).ty() { + (rhs, lhs) + } else { + (lhs, rhs) + }; + + match (self.get_ty_var(*lhs).ty(), self.get_ty_var(*rhs).ty()) { + (Some(Ty::Ptr(_)), Some(Ty::Ptr(_))) => { + *self.get_ty_var_mut(*expr) = TyVar::Typed(&Ty::Int(ir::IntTy::Isize)); + progress |= true; + + false + } + (Some(ty @ Ty::Ptr(_)), Some(Ty::Int(_) | Ty::UInt(_))) => { + *self.get_ty_var_mut(*expr) = TyVar::Typed(ty); + progress |= true; + + false + } + (None, None) => { + self.eq(*expr, *lhs); + self.eq(*lhs, *rhs); + + progress |= true; + + true + } + _ => true, + } + } + }); + self.constraints.append(&mut constraints); + + progress + } + + pub fn solve(&mut self) { + loop { + if !self.apply_constraints() { + break; + } + } + } + + pub fn resolve_ty(&self, ctx: &Context<'ir>, ty: &'ir Ty<'ir>) -> &'ir Ty<'ir> { + // TODO: check if there already exist such a type instead of allocationg a new one + match ty { + Ty::Infer(id) => self.resolve_ty(ctx, self.get_ty_var(*id).ty().unwrap()), + Ty::Ptr(ty) => ctx.allocator.alloc(Ty::Ptr(self.resolve_ty(ctx, *ty))), + ty => ty, + } + } +} diff --git a/src/type_table.rs b/src/type_table.rs deleted file mode 100644 index 1a58eab..0000000 --- a/src/type_table.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::{ - archs::Arch, - codegen::Offset, - scope::Scope, - types::{self, TypeError}, -}; - -#[derive(Debug, Clone, PartialEq)] -pub enum Type { - Struct(TypeStruct), -} - -#[derive(Debug, Clone, PartialEq)] -pub struct TypeStruct { - pub name: String, - pub fields: Vec<(String, types::Type)>, - pub methods: Vec, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct TypeStructMethod { - pub return_type: types::Type, - pub name: String, - pub params: Vec<(String, types::Type)>, -} - -impl TypeStruct { - pub fn offset(&self, arch: &Arch, name: &str, scope: &Scope) -> Result { - let mut offset: usize = 0; - - for (field_name, type_) in &self.fields { - let size = arch.size(type_, scope); - offset = offset.next_multiple_of(size); - if name == field_name { - break; - } - - offset += size; - } - - Ok(Offset(offset as isize)) - } - - pub fn get_field_type(&self, field: &str) -> Option<&types::Type> { - self.fields - .iter() - .find(|(name, _)| name == field) - .map(|(_, type_)| type_) - } - - pub fn find_method(&self, name: &str) -> Option<&TypeStructMethod> { - self.methods.iter().find(|method| method.name == name) - } - - pub fn contains(&self, field: &str) -> bool { - self.fields.iter().any(|(name, _)| name == field) - } - - pub fn types(&self) -> Vec<&types::Type> { - self.fields.iter().map(|(_, type_)| type_).collect() - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct TypeTable(pub Vec); - -impl TypeTable { - pub fn new() -> Self { - Self(Vec::new()) - } - - pub fn define(&mut self, type_: Type) { - self.0.push(type_); - } - - pub fn find(&self, type_name: &str) -> Option<&Type> { - self.0.iter().find(|type_| match type_ { - Type::Struct(type_struct) => type_struct.name == type_name, - }) - } -} diff --git a/src/types/error.rs b/src/types/error.rs deleted file mode 100644 index d332af2..0000000 --- a/src/types/error.rs +++ /dev/null @@ -1,24 +0,0 @@ -use super::Type; -use thiserror::Error; - -#[derive(Error, Debug, PartialEq)] -pub enum TypeError { - #[error("Operation between {0} and {1} are not allowed")] - Promotion(Type, Type), - #[error("Ident {0} not found")] - IdentNotFound(String), - #[error("Can't assign {0} to {1}")] - Assignment(Type, Type), - #[error("Can't cast {0} into {1}")] - Cast(Type, Type), - #[error("Expected return value of type {1}, got {0} instead")] - Return(Type, Type), - #[error("Variable can't be of type void")] - VoidVariable, - #[error("Type '{0}' doens't exits")] - Nonexistent(String), - #[error("Type {0} is not pointer")] - Deref(Type), - #[error("Mismatched types expected {0}, found {1}")] - Mismatched(Type, Type), -} diff --git a/src/types/mod.rs b/src/types/mod.rs deleted file mode 100644 index 1e40b48..0000000 --- a/src/types/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod error; -mod types; - -pub use error::TypeError; -pub use types::{IntType, Type, TypeArray, UintType};