diff --git a/crates/rue-compiler/src/codegen.rs b/crates/rue-compiler/src/codegen.rs index 2df8a28..2f608f8 100644 --- a/crates/rue-compiler/src/codegen.rs +++ b/crates/rue-compiler/src/codegen.rs @@ -19,6 +19,7 @@ struct Ops { f: NodePtr, r: NodePtr, l: NodePtr, + x: NodePtr, eq: NodePtr, sha256: NodePtr, strlen: NodePtr, @@ -43,6 +44,7 @@ impl<'a> Codegen<'a> { f: allocator.new_small_number(5).unwrap(), r: allocator.new_small_number(6).unwrap(), l: allocator.new_small_number(7).unwrap(), + x: allocator.new_small_number(8).unwrap(), eq: allocator.new_small_number(9).unwrap(), sha256: allocator.new_small_number(11).unwrap(), strlen: allocator.new_small_number(13).unwrap(), @@ -70,6 +72,7 @@ impl<'a> Codegen<'a> { Lir::FunctionBody(body) => self.gen_quote(body), Lir::First(value) => self.gen_first(value), Lir::Rest(value) => self.gen_rest(value), + Lir::Raise => self.list(&[self.ops.x]), Lir::Sha256(value) => self.gen_sha256(value), Lir::IsCons(value) => self.gen_is_cons(value), Lir::Strlen(value) => self.gen_strlen(value), diff --git a/crates/rue-compiler/src/error.rs b/crates/rue-compiler/src/error.rs index d63ce5a..dbdd6ae 100644 --- a/crates/rue-compiler/src/error.rs +++ b/crates/rue-compiler/src/error.rs @@ -115,6 +115,12 @@ pub enum DiagnosticInfo { #[error("redundant check against same type `{0}`")] RedundantTypeGuard(String), + + #[error("implicit return is not allowed in if statements, use an explicit return statement")] + ImplicitReturnInIf, + + #[error("explicit return is not allowed in expressions")] + ExplicitReturnInExpr, } /// Join a list of names into a string, wrapped in backticks. diff --git a/crates/rue-compiler/src/hir.rs b/crates/rue-compiler/src/hir.rs index 93d25a2..76d73e8 100644 --- a/crates/rue-compiler/src/hir.rs +++ b/crates/rue-compiler/src/hir.rs @@ -24,6 +24,7 @@ pub enum Hir { First(HirId), Rest(HirId), Not(HirId), + Raise, Sha256(HirId), IsCons(HirId), Strlen(HirId), diff --git a/crates/rue-compiler/src/lir.rs b/crates/rue-compiler/src/lir.rs index 612c7d6..42b2dfd 100644 --- a/crates/rue-compiler/src/lir.rs +++ b/crates/rue-compiler/src/lir.rs @@ -11,6 +11,7 @@ pub enum Lir { FunctionBody(LirId), First(LirId), Rest(LirId), + Raise, Sha256(LirId), IsCons(LirId), Strlen(LirId), diff --git a/crates/rue-compiler/src/lowerer.rs b/crates/rue-compiler/src/lowerer.rs index 4550582..f076373 100644 --- a/crates/rue-compiler/src/lowerer.rs +++ b/crates/rue-compiler/src/lowerer.rs @@ -10,7 +10,7 @@ use rowan::TextRange; use rue_parser::{ AstNode, BinaryExpr, BinaryOp, Block, CastExpr, ConstItem, EnumItem, Expr, FieldAccess, FunctionCall, FunctionItem, FunctionType as AstFunctionType, GroupExpr, GuardExpr, IfExpr, - IndexAccess, InitializerExpr, InitializerField, Item, LambdaExpr, ListExpr, ListType, + IndexAccess, InitializerExpr, InitializerField, Item, LambdaExpr, LetStmt, ListExpr, ListType, LiteralExpr, PairExpr, PairType, Path, PrefixExpr, PrefixOp, Root, Stmt, StructField, StructItem, SyntaxKind, SyntaxToken, Type as AstType, TypeAliasItem, }; @@ -279,7 +279,8 @@ impl<'a> Lowerer<'a> { }; self.scope_stack.push(scope_id); - let output = self.compile_block_expr(body, None, Some(ty.return_type())); + let (output, _explicit_return) = + self.compile_block_expr(body, None, Some(ty.return_type())); self.scope_stack.pop().unwrap(); self.type_check( @@ -387,49 +388,140 @@ impl<'a> Lowerer<'a> { } } + fn compile_let_stmt(&mut self, let_stmt: LetStmt) -> Option { + let expected_type = let_stmt.ty().map(|ty| self.compile_type(ty)); + + let value = let_stmt + .expr() + .map(|expr| self.compile_expr(expr, expected_type)) + .unwrap_or(self.unknown()); + + if let Some(expected_type) = expected_type { + self.type_check(value.ty(), expected_type, let_stmt.syntax().text_range()); + } + + let Some(name) = let_stmt.name() else { + return None; + }; + + let symbol_id = self.db.alloc_symbol(Symbol::LetBinding { + type_id: expected_type.unwrap_or(value.ty()), + hir_id: value.hir(), + }); + + let mut let_scope = Scope::default(); + let_scope.define_symbol(name.to_string(), symbol_id); + let scope_id = self.db.alloc_scope(let_scope); + self.scope_stack.push(scope_id); + + Some(scope_id) + } + fn compile_block_expr( &mut self, block: Block, scope_id: Option, expected_type: Option, - ) -> Value { + ) -> (Value, bool) { if let Some(scope_id) = scope_id { self.scope_stack.push(scope_id); } self.compile_items(block.items()); - let mut let_scope_ids = Vec::new(); + enum Statement { + Let(ScopeId), + If(HirId, HirId), + Return(Value), + } + + let mut statements = Vec::new(); for stmt in block.stmts() { match stmt { Stmt::LetStmt(let_stmt) => { - let expected_type = let_stmt.ty().map(|ty| self.compile_type(ty)); + let Some(scope_id) = self.compile_let_stmt(let_stmt) else { + continue; + }; + statements.push(Statement::Let(scope_id)); + } + Stmt::IfStmt(if_stmt) => { + let condition = if_stmt + .condition() + .map(|condition| self.compile_expr(condition, Some(self.bool_type))); - let value = let_stmt - .expr() - .map(|expr| self.compile_expr(expr, expected_type)) - .unwrap_or(self.unknown()); + if let Some(condition) = condition.as_ref() { + self.type_guards.push(condition.then_guards()); + } - if let Some(expected_type) = expected_type { - self.type_check(value.ty(), expected_type, let_stmt.syntax().text_range()); + let then_block = if_stmt + .then_block() + .map(|then_block| { + let scope_id = self.db.alloc_scope(Scope::default()); + let (value, explicit_return) = self.compile_block_expr( + then_block.clone(), + Some(scope_id), + expected_type, + ); + if !explicit_return { + self.error( + DiagnosticInfo::ImplicitReturnInIf, + then_block.syntax().text_range(), + ); + } + value + }) + .unwrap_or_else(|| self.unknown()); + + if condition.is_some() { + self.type_guards.pop().unwrap(); } - let Some(name) = let_stmt.name() else { - continue; - }; + self.type_check( + then_block.ty(), + expected_type.unwrap_or(self.unknown_type), + block.syntax().text_range(), + ); + + let else_guards = condition + .as_ref() + .map(|condition| condition.else_guards()) + .unwrap_or_default(); + + self.type_guards.push(else_guards); + + statements.push(Statement::If( + condition + .map(|condition| condition.hir()) + .unwrap_or(self.unknown_hir), + then_block.hir(), + )); + } + Stmt::ReturnStmt(return_stmt) => { + let value = return_stmt + .expr() + .map(|expr| self.compile_expr(expr, expected_type)) + .unwrap_or_else(|| self.unknown()); + statements.push(Statement::Return(value)); + } + Stmt::AssertStmt(assert_stmt) => { + let condition = assert_stmt + .expr() + .map(|condition| self.compile_expr(condition, Some(self.bool_type))) + .unwrap_or_else(|| self.unknown()); - let symbol_id = self.db.alloc_symbol(Symbol::LetBinding { - type_id: expected_type.unwrap_or(value.ty()), - hir_id: value.hir(), - }); + self.type_guards.push(condition.then_guards()); - let mut let_scope = Scope::default(); - let_scope.define_symbol(name.to_string(), symbol_id); - let scope_id = self.db.alloc_scope(let_scope); - self.scope_stack.push(scope_id); + self.type_check( + condition.ty(), + self.bool_type, + assert_stmt.syntax().text_range(), + ); + + let not_condition = self.db.alloc_hir(Hir::Not(condition.hir())); + let raise = self.db.alloc_hir(Hir::Raise); - let_scope_ids.push(scope_id); + statements.push(Statement::If(not_condition, raise)) } } } @@ -439,22 +531,44 @@ impl<'a> Lowerer<'a> { .map(|expr| self.compile_expr(expr, expected_type)) .unwrap_or(self.unknown()); - for scope_id in let_scope_ids.into_iter().rev() { - body = Value::typed( - self.db.alloc_hir(Hir::Scope { - scope_id, - value: body.hir(), - }), - body.ty(), - ); - self.scope_stack.pop().unwrap(); + let mut explicit_return = false; + + for statement in statements.into_iter().rev() { + match statement { + Statement::Let(scope_id) => { + body = Value::typed( + self.db.alloc_hir(Hir::Scope { + scope_id, + value: body.hir(), + }), + body.ty(), + ); + self.scope_stack.pop().unwrap(); + } + Statement::Return(value) => { + body = value; + explicit_return = true; + } + Statement::If(condition, then_block) => { + self.type_guards.pop().unwrap(); + + body = Value::typed( + self.db.alloc_hir(Hir::If { + condition, + then_block, + else_block: body.hir(), + }), + body.ty(), + ); + } + } } if scope_id.is_some() { self.scope_stack.pop().unwrap(); } - body + (body, explicit_return) } fn compile_expr(&mut self, expr: Expr, expected_type: Option) -> Value { @@ -466,7 +580,15 @@ impl<'a> Lowerer<'a> { Expr::PairExpr(pair) => self.compile_pair_expr(pair, expected_type), Expr::Block(block) => { let scope_id = self.db.alloc_scope(Scope::default()); - self.compile_block_expr(block, Some(scope_id), expected_type) + let (value, explicit_return) = + self.compile_block_expr(block.clone(), Some(scope_id), expected_type); + if explicit_return { + self.error( + DiagnosticInfo::ExplicitReturnInExpr, + block.syntax().text_range(), + ); + } + value } Expr::LambdaExpr(lambda) => self.compile_lambda_expr(lambda, expected_type), Expr::PrefixExpr(prefix) => self.compile_prefix_expr(prefix), @@ -1100,7 +1222,7 @@ impl<'a> Lowerer<'a> { fn compile_if_expr(&mut self, if_expr: IfExpr, expected_type: Option) -> Value { let condition = if_expr .condition() - .map(|condition| self.compile_expr(condition, None)); + .map(|condition| self.compile_expr(condition, Some(self.bool_type))); if let Some(condition) = condition.as_ref() { self.type_guards.push(condition.then_guards()); @@ -1109,6 +1231,7 @@ impl<'a> Lowerer<'a> { let then_block = if_expr.then_block().map(|then_block| { let scope_id = self.db.alloc_scope(Scope::default()); self.compile_block_expr(then_block, Some(scope_id), expected_type) + .0 }); if condition.is_some() { @@ -1119,9 +1242,13 @@ impl<'a> Lowerer<'a> { self.type_guards.push(condition.else_guards()); } + let expected_type = + expected_type.or_else(|| then_block.as_ref().map(|then_block| then_block.ty())); + let else_block = if_expr.else_block().map(|else_block| { let scope_id = self.db.alloc_scope(Scope::default()); self.compile_block_expr(else_block, Some(scope_id), expected_type) + .0 }); if condition.is_some() { diff --git a/crates/rue-compiler/src/optimizer.rs b/crates/rue-compiler/src/optimizer.rs index 9793790..55bcd5c 100644 --- a/crates/rue-compiler/src/optimizer.rs +++ b/crates/rue-compiler/src/optimizer.rs @@ -39,7 +39,7 @@ impl<'a> Optimizer<'a> { fn compute_captures_hir(&mut self, scope_id: ScopeId, hir_id: HirId) { match self.db.hir(hir_id).clone() { Hir::Unknown => unreachable!(), - Hir::Atom(_) => {} + Hir::Atom(_) | Hir::Raise => {} Hir::Reference(symbol_id) => self.compute_reference_captures(scope_id, symbol_id), Hir::Scope { scope_id: new_scope_id, @@ -317,6 +317,7 @@ impl<'a> Optimizer<'a> { Hir::First(value) => self.opt_first(scope_id, *value), Hir::Rest(value) => self.opt_rest(scope_id, *value), Hir::Not(value) => self.opt_not(scope_id, *value), + Hir::Raise => self.db.alloc_lir(Lir::Raise), Hir::Sha256(value) => self.opt_sha256(scope_id, *value), Hir::IsCons(value) => self.opt_is_cons(scope_id, *value), Hir::Strlen(value) => self.opt_strlen(scope_id, *value), diff --git a/crates/rue-lexer/src/lib.rs b/crates/rue-lexer/src/lib.rs index e2c2a4f..460cda2 100644 --- a/crates/rue-lexer/src/lib.rs +++ b/crates/rue-lexer/src/lib.rs @@ -111,6 +111,8 @@ impl<'a> Lexer<'a> { "const" => TokenKind::Const, "if" => TokenKind::If, "else" => TokenKind::Else, + "return" => TokenKind::Return, + "assert" => TokenKind::Assert, "nil" => TokenKind::Nil, "true" => TokenKind::True, "false" => TokenKind::False, @@ -275,6 +277,7 @@ mod tests { check("const", &[TokenKind::Const]); check("if", &[TokenKind::If]); check("else", &[TokenKind::Else]); + check("return", &[TokenKind::Return]); check("true", &[TokenKind::True]); check("false", &[TokenKind::False]); check("nil", &[TokenKind::Nil]); diff --git a/crates/rue-lexer/src/token_kind.rs b/crates/rue-lexer/src/token_kind.rs index 8114d45..f29e293 100644 --- a/crates/rue-lexer/src/token_kind.rs +++ b/crates/rue-lexer/src/token_kind.rs @@ -19,6 +19,8 @@ pub enum TokenKind { Const, If, Else, + Return, + Assert, Nil, True, False, diff --git a/crates/rue-parser/src/ast.rs b/crates/rue-parser/src/ast.rs index e3f0a91..a3f9441 100644 --- a/crates/rue-parser/src/ast.rs +++ b/crates/rue-parser/src/ast.rs @@ -120,8 +120,11 @@ ast_node!(PairType); ast_node!(FunctionType); ast_node!(FunctionTypeParam); -ast_enum!(Stmt, LetStmt); +ast_enum!(Stmt, LetStmt, IfStmt, ReturnStmt, AssertStmt); ast_node!(LetStmt); +ast_node!(IfStmt); +ast_node!(ReturnStmt); +ast_node!(AssertStmt); impl Root { pub fn items(&self) -> Vec { @@ -317,6 +320,28 @@ impl LetStmt { } } +impl IfStmt { + pub fn condition(&self) -> Option { + self.syntax().children().find_map(Expr::cast) + } + + pub fn then_block(&self) -> Option { + self.syntax().children().find_map(Block::cast) + } +} + +impl ReturnStmt { + pub fn expr(&self) -> Option { + self.syntax().children().find_map(Expr::cast) + } +} + +impl AssertStmt { + pub fn expr(&self) -> Option { + self.syntax().children().find_map(Expr::cast) + } +} + impl InitializerExpr { pub fn path(&self) -> Option { self.syntax().children().find_map(Path::cast) diff --git a/crates/rue-parser/src/grammar.rs b/crates/rue-parser/src/grammar.rs index 51aef75..8822d12 100644 --- a/crates/rue-parser/src/grammar.rs +++ b/crates/rue-parser/src/grammar.rs @@ -135,16 +135,24 @@ fn const_item(p: &mut Parser) { fn block(p: &mut Parser) { p.start(SyntaxKind::Block); p.expect(SyntaxKind::OpenBrace); - loop { + while !p.at(SyntaxKind::CloseBrace) && !p.at(SyntaxKind::Eof) { if p.at(SyntaxKind::Let) { let_stmt(p); + } else if p.at(SyntaxKind::Return) { + return_stmt(p); + } else if p.at(SyntaxKind::If) { + if if_stmt_maybe_else(p, false) { + break; + } + } else if p.at(SyntaxKind::Assert) { + assert_stmt(p); } else if p.at(SyntaxKind::Fun) || p.at(SyntaxKind::Type) || p.at(SyntaxKind::Const) { item(p); } else { + expr(p); break; } } - expr(p); p.expect(SyntaxKind::CloseBrace); p.finish(); } @@ -162,6 +170,40 @@ fn let_stmt(p: &mut Parser) { p.finish(); } +fn if_stmt_maybe_else(p: &mut Parser, expr_only: bool) -> bool { + let cp = p.checkpoint(); + p.expect(SyntaxKind::If); + expr(p); + block(p); + let mut has_else = false; + if expr_only || p.at(SyntaxKind::Else) { + p.start_at(cp, SyntaxKind::IfExpr); + p.expect(SyntaxKind::Else); + block(p); + has_else = true; + } else { + p.start_at(cp, SyntaxKind::IfStmt); + } + p.finish(); + has_else +} + +fn return_stmt(p: &mut Parser) { + p.start(SyntaxKind::ReturnStmt); + p.expect(SyntaxKind::Return); + expr(p); + p.expect(SyntaxKind::Semicolon); + p.finish(); +} + +fn assert_stmt(p: &mut Parser) { + p.start(SyntaxKind::AssertStmt); + p.expect(SyntaxKind::Assert); + expr(p); + p.expect(SyntaxKind::Semicolon); + p.finish(); +} + fn path(p: &mut Parser) { p.start(SyntaxKind::Path); p.expect(SyntaxKind::Ident); @@ -246,7 +288,7 @@ fn expr_binding_power(p: &mut Parser, minimum_binding_power: u8) { p.expect(SyntaxKind::CloseParen); p.finish(); } else if p.at(SyntaxKind::If) { - if_expr(p); + if_stmt_maybe_else(p, true); } else if p.at(SyntaxKind::Fun) { lambda_expr(p); } else { @@ -384,16 +426,6 @@ fn lambda_param(p: &mut Parser) { p.finish(); } -fn if_expr(p: &mut Parser) { - p.start(SyntaxKind::IfExpr); - p.expect(SyntaxKind::If); - expr(p); - block(p); - p.expect(SyntaxKind::Else); - block(p); - p.finish(); -} - const TYPE_RECOVERY_SET: &[SyntaxKind] = &[SyntaxKind::OpenBrace, SyntaxKind::CloseBrace]; fn ty(p: &mut Parser) { diff --git a/crates/rue-parser/src/parser.rs b/crates/rue-parser/src/parser.rs index 57b4135..e1bc452 100644 --- a/crates/rue-parser/src/parser.rs +++ b/crates/rue-parser/src/parser.rs @@ -190,6 +190,8 @@ fn convert_tokens<'a>( TokenKind::Const => SyntaxKind::Const, TokenKind::If => SyntaxKind::If, TokenKind::Else => SyntaxKind::Else, + TokenKind::Return => SyntaxKind::Return, + TokenKind::Assert => SyntaxKind::Assert, TokenKind::Nil => SyntaxKind::Nil, TokenKind::True => SyntaxKind::True, TokenKind::False => SyntaxKind::False, diff --git a/crates/rue-parser/src/syntax_kind.rs b/crates/rue-parser/src/syntax_kind.rs index 025b2d3..f5655ca 100644 --- a/crates/rue-parser/src/syntax_kind.rs +++ b/crates/rue-parser/src/syntax_kind.rs @@ -29,6 +29,8 @@ pub enum SyntaxKind { Const, If, Else, + Return, + Assert, Nil, True, False, @@ -73,6 +75,9 @@ pub enum SyntaxKind { ConstItem, LetStmt, + IfStmt, + ReturnStmt, + AssertStmt, Block, Path, @@ -131,6 +136,8 @@ impl fmt::Display for SyntaxKind { SyntaxKind::Const => "'const'", SyntaxKind::If => "'if'", SyntaxKind::Else => "'else'", + SyntaxKind::Return => "'return'", + SyntaxKind::Assert => "'assert'", SyntaxKind::Nil => "'nil'", SyntaxKind::True => "'true'", SyntaxKind::False => "'false'", @@ -175,6 +182,9 @@ impl fmt::Display for SyntaxKind { SyntaxKind::ConstItem => "const item", SyntaxKind::LetStmt => "let statement", + SyntaxKind::IfStmt => "if statement", + SyntaxKind::ReturnStmt => "return statement", + SyntaxKind::AssertStmt => "assert statement", SyntaxKind::Block => "block", SyntaxKind::Path => "identifier path", diff --git a/tests.toml b/tests.toml index a37aeb7..24d7175 100644 --- a/tests.toml +++ b/tests.toml @@ -116,3 +116,10 @@ cost = 7060 input = "()" output = "6" hash = "c65b731dc627a8ad1357dee68e27e3de33a582bebdb0bc2b503310e20ad360cc" + +[early_return] +bytes = 115 +cost = 2716 +input = "(50)" +output = "\"Small\"" +hash = "cf32bd32ff878903a6962acd218c451474e0ef3ae3c649f79aa5abca1bff730b" diff --git a/tests/early_return.rue b/tests/early_return.rue new file mode 100644 index 0000000..41269e3 --- /dev/null +++ b/tests/early_return.rue @@ -0,0 +1,14 @@ +fun main(num: Int) -> Bytes { + if num > 100 { + return "Large"; + } + + if num > 50 { + return "Medium"; + } + + assert num > 20; + + "Small" +} +