diff --git a/crates/rue-compiler/src/compiler.rs b/crates/rue-compiler/src/compiler.rs index bbe7946..ad7cc33 100644 --- a/crates/rue-compiler/src/compiler.rs +++ b/crates/rue-compiler/src/compiler.rs @@ -262,7 +262,7 @@ impl<'a> Compiler<'a> { } fn symbol_type(&self, guard_path: &GuardPath) -> Option { - for guards in &self.type_guard_stack { + for guards in self.type_guard_stack.iter().rev() { if let Some(guard) = guards.get(guard_path) { return Some(*guard); } diff --git a/crates/rue-compiler/src/compiler/block.rs b/crates/rue-compiler/src/compiler/block.rs index 896bf0e..83d103d 100644 --- a/crates/rue-compiler/src/compiler/block.rs +++ b/crates/rue-compiler/src/compiler/block.rs @@ -8,9 +8,22 @@ use crate::{ use super::{stmt::Statement, Compiler}; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum BlockTerminator { + Implicit, + Return, + Raise, +} + +#[derive(Debug, Clone)] +pub struct BlockSummary { + pub value: Value, + pub terminator: BlockTerminator, +} + impl Compiler<'_> { /// Compile a block expression into the current scope, returning the HIR and whether there was an explicit return. - pub fn compile_block(&mut self, block: &Block, expected_type: Option) -> (Value, bool) { + pub fn compile_block(&mut self, block: &Block, expected_type: Option) -> BlockSummary { // Compile all of the items in the block first. // This means that statements can use item symbols in any order, // but items cannot use statement symbols. @@ -19,7 +32,7 @@ impl Compiler<'_> { self.compile_items(&items, declarations); let mut statements = Vec::new(); - let mut explicit_return = false; + let mut terminator = BlockTerminator::Implicit; let mut is_terminated = block.expr().is_some(); for stmt in block.stmts() { @@ -53,7 +66,7 @@ impl Compiler<'_> { return_stmt.syntax().text_range(), ); - explicit_return = true; + terminator = BlockTerminator::Return; is_terminated = true; statements.push(Statement::Return(value)); @@ -67,6 +80,7 @@ impl Compiler<'_> { let hir_id = self.db.alloc_hir(Hir::Raise(value)); + terminator = BlockTerminator::Raise; is_terminated = true; statements.push(Statement::Return(Value::new(hir_id, self.builtins.unknown))); @@ -88,6 +102,7 @@ impl Compiler<'_> { // If the condition is false, we raise an error. // So we can assume that the condition is true from this point on. // This will be popped in reverse order later after all statements have been lowered. + self.type_guard_stack.push(condition.then_guards()); let not_condition = self.db.alloc_hir(Hir::Op(Op::Not, condition.hir_id)); @@ -156,6 +171,9 @@ impl Compiler<'_> { } } - (body, explicit_return) + BlockSummary { + value: body, + terminator, + } } } diff --git a/crates/rue-compiler/src/compiler/expr/block_expr.rs b/crates/rue-compiler/src/compiler/expr/block_expr.rs index 1974d5b..8d63073 100644 --- a/crates/rue-compiler/src/compiler/expr/block_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/block_expr.rs @@ -1,20 +1,25 @@ use rue_parser::{AstNode, Block}; -use crate::{compiler::Compiler, scope::Scope, value::Value, ErrorKind, TypeId}; +use crate::{ + compiler::{block::BlockTerminator, Compiler}, + scope::Scope, + value::Value, + ErrorKind, TypeId, +}; impl Compiler<'_> { pub fn compile_block_expr(&mut self, block: &Block, expected_type: Option) -> Value { let scope_id = self.db.alloc_scope(Scope::default()); self.scope_stack.push(scope_id); - let (value, explicit_return) = self.compile_block(block, expected_type); + let summary = self.compile_block(block, expected_type); self.scope_stack.pop().unwrap(); - if explicit_return { + if summary.terminator == BlockTerminator::Return { self.db .error(ErrorKind::ExplicitReturnInExpr, block.syntax().text_range()); } - value + summary.value } } diff --git a/crates/rue-compiler/src/compiler/item/function_item.rs b/crates/rue-compiler/src/compiler/item/function_item.rs index e230040..f6cd53c 100644 --- a/crates/rue-compiler/src/compiler/item/function_item.rs +++ b/crates/rue-compiler/src/compiler/item/function_item.rs @@ -183,7 +183,7 @@ impl Compiler<'_> { // We don't care about explicit returns in this context. self.scope_stack.push(scope_id); self.allow_generic_inference_stack.push(false); - let value = self.compile_block(&body, Some(ty.return_type)).0; + let value = self.compile_block(&body, Some(ty.return_type)).value; self.allow_generic_inference_stack.pop().unwrap(); self.scope_stack.pop().unwrap(); diff --git a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs index ff930f4..0495bc4 100644 --- a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs +++ b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use rue_parser::{AstNode, IfStmt}; use crate::{ - compiler::Compiler, + compiler::{block::BlockTerminator, Compiler}, scope::Scope, value::{GuardPath, TypeOverride}, ErrorKind, HirId, TypeId, @@ -38,7 +38,7 @@ impl Compiler<'_> { // Compile the then block. self.scope_stack.push(scope_id); - let (value, explicit_return) = self.compile_block(&then_block, expected_type); + let summary = self.compile_block(&then_block, expected_type); self.scope_stack.pop().unwrap(); // Pop the type guards, since we've left the scope. @@ -46,14 +46,14 @@ impl Compiler<'_> { // If there's an implicit return, we want to raise an error. // This could technically work but makes the intent of the code unclear. - if !explicit_return { + if summary.terminator == BlockTerminator::Implicit { self.db.error( ErrorKind::ImplicitReturnInIf, then_block.syntax().text_range(), ); } - value + summary.value } else { self.unknown() }; diff --git a/crates/rue-compiler/src/lib.rs b/crates/rue-compiler/src/lib.rs index 540d3aa..2d8a7fd 100644 --- a/crates/rue-compiler/src/lib.rs +++ b/crates/rue-compiler/src/lib.rs @@ -37,7 +37,7 @@ pub fn compile(allocator: &mut Allocator, root: &Root, mut should_codegen: bool) let main_module_id = load_module(&mut ctx, root); let symbol_table = compile_modules(ctx); - let main = try_export_main(&mut db, main_module_id).expect("missing main function"); + let main = try_export_main(&mut db, main_module_id); let graph = build_graph( &mut db, &symbol_table, @@ -50,7 +50,12 @@ pub fn compile(allocator: &mut Allocator, root: &Root, mut should_codegen: bool) Output { diagnostics: db.diagnostics().to_vec(), node_ptr: if should_codegen { - codegen(allocator, &mut db, &graph, main) + codegen( + allocator, + &mut db, + &graph, + main.expect("missing main function"), + ) } else { NodePtr::default() }, diff --git a/tests/enum/enum_type_guard.rue b/tests/enum/enum_type_guard.rue new file mode 100644 index 0000000..ff3c3cd --- /dev/null +++ b/tests/enum/enum_type_guard.rue @@ -0,0 +1,21 @@ +enum Color { + Red, + Green, + Blue, +} + +fun main() -> Int { + let color: Color = Color::Red; + + if color is Color::Green { + raise "Unreachable"; + } + + if color is Color::Blue { + raise "Unreachable"; + } + + assert color is Color::Red; + let red: Color::Red = color; + red as Int +}