diff --git a/compiler/formatter/src/format.rs b/compiler/formatter/src/format.rs index c32e6dab3..3b6fcd994 100644 --- a/compiler/formatter/src/format.rs +++ b/compiler/formatter/src/format.rs @@ -744,6 +744,7 @@ pub fn format_cst<'a>( } CstKind::MatchCase { pattern, + condition: _, // TODO: format match case conditions arrow, body, } => { diff --git a/compiler/frontend/src/ast.rs b/compiler/frontend/src/ast.rs index 41d7ce34c..82d648b72 100644 --- a/compiler/frontend/src/ast.rs +++ b/compiler/frontend/src/ast.rs @@ -131,6 +131,7 @@ pub struct Match { #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct MatchCase { pub pattern: Box, + pub condition: Option>, pub body: Vec, } #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -248,7 +249,14 @@ impl FindAst for Match { } impl FindAst for MatchCase { fn find(&self, id: &Id) -> Option<&Ast> { - self.pattern.find(id).or_else(|| self.body.find(id)) + self.pattern + .find(id) + .or_else(|| { + self.condition + .as_ref() + .and_then(|box condition| condition.find(id)) + }) + .or_else(|| self.body.find(id)) } } impl FindAst for OrPattern { @@ -352,8 +360,15 @@ impl CollectErrors for Ast { expression.collect_errors(errors); cases.collect_errors(errors); } - AstKind::MatchCase(MatchCase { pattern, body }) => { + AstKind::MatchCase(MatchCase { + pattern, + condition, + body, + }) => { pattern.collect_errors(errors); + if let Some(box condition) = condition { + condition.collect_errors(errors); + } body.collect_errors(errors); } AstKind::OrPattern(OrPattern(patterns)) => { diff --git a/compiler/frontend/src/ast_to_hir.rs b/compiler/frontend/src/ast_to_hir.rs index cadb2b1a7..cdb5009fd 100644 --- a/compiler/frontend/src/ast_to_hir.rs +++ b/compiler/frontend/src/ast_to_hir.rs @@ -361,7 +361,11 @@ impl Context<'_> { let cases = cases .iter() .map(|case| match &case.kind { - AstKind::MatchCase(MatchCase { box pattern, body }) => { + AstKind::MatchCase(MatchCase { + box pattern, + condition, + body, + }) => { let (pattern, pattern_identifiers) = self.lower_pattern(pattern); let reset_state = self.start_scope(); @@ -372,10 +376,25 @@ impl Context<'_> { name.clone(), ); } + + let condition = condition.as_ref().map(|box condition| { + let condition_reset_state = self.start_scope(); + self.compile_single(condition); + self.end_scope(condition_reset_state) + }); + + let body_reset_state = self.start_scope(); self.compile(body.as_ref()); - let body = self.end_scope(reset_state); + let body = self.end_scope(body_reset_state); - (pattern, body) + let identifier_expressions = self.end_scope(reset_state); + + hir::MatchCase { + pattern, + identifier_expressions, + condition, + body, + } } AstKind::Error { errors } => { let pattern = Pattern::Error { @@ -384,9 +403,19 @@ impl Context<'_> { let reset_state = self.start_scope(); self.compile(&[]); - let body = self.end_scope(reset_state); - (pattern, body) + let body_reset_state = self.start_scope(); + self.compile(&[]); + let body = self.end_scope(body_reset_state); + + let identifier_expressions = self.end_scope(reset_state); + + hir::MatchCase { + pattern, + identifier_expressions, + condition: None, + body, + } } _ => unreachable!("Expected match case in match cases, got {case:?}."), }) diff --git a/compiler/frontend/src/cst/error.rs b/compiler/frontend/src/cst/error.rs index 73e1eb044..196c03ea7 100644 --- a/compiler/frontend/src/cst/error.rs +++ b/compiler/frontend/src/cst/error.rs @@ -8,6 +8,7 @@ pub enum CstError { ListNotClosed, MatchCaseMissesArrow, MatchCaseMissesBody, + MatchCaseMissesCondition, MatchMissesCases, OpeningParenthesisMissesExpression, OrPatternMissesRight, diff --git a/compiler/frontend/src/cst/is_multiline.rs b/compiler/frontend/src/cst/is_multiline.rs index edebe0466..ada83953c 100644 --- a/compiler/frontend/src/cst/is_multiline.rs +++ b/compiler/frontend/src/cst/is_multiline.rs @@ -107,9 +107,17 @@ impl IsMultiline for CstKind { } => expression.is_multiline() || percent.is_multiline() || cases.is_multiline(), Self::MatchCase { pattern, + condition, arrow, body, - } => pattern.is_multiline() || arrow.is_multiline() || body.is_multiline(), + } => { + pattern.is_multiline() + || condition.as_deref().map_or(false, |(comma, condition)| { + comma.is_multiline() || condition.is_multiline() + }) + || arrow.is_multiline() + || body.is_multiline() + } Self::Function { opening_curly_brace, parameters_and_arrow, diff --git a/compiler/frontend/src/cst/kind.rs b/compiler/frontend/src/cst/kind.rs index dd9e566cd..18a251a41 100644 --- a/compiler/frontend/src/cst/kind.rs +++ b/compiler/frontend/src/cst/kind.rs @@ -1,6 +1,6 @@ use super::{Cst, CstData, CstError}; use num_bigint::BigUint; -use std::fmt::{self, Display, Formatter}; +use std::fmt::{self, Display, Formatter, Pointer}; use strum_macros::EnumIs; #[derive(Clone, Debug, EnumIs, Eq, Hash, PartialEq)] @@ -104,6 +104,7 @@ pub enum CstKind { }, MatchCase { pattern: Box>, + condition: Option>, arrow: Box>, body: Vec>, }, @@ -128,6 +129,7 @@ pub enum IntRadix { Binary, Hexadecimal, } +pub type MatchCaseWithComma = Box<(Cst, Cst)>; pub type FunctionParametersAndArrow = (Vec>, Box>); impl CstKind { @@ -289,10 +291,14 @@ impl CstKind { } Self::MatchCase { pattern, + condition, arrow, body, } => { let mut children = vec![pattern.as_ref(), arrow.as_ref()]; + if let Some(box (comma, condition)) = condition { + children.extend([&comma, &condition]); + } children.extend(body); children } @@ -505,11 +511,16 @@ impl Display for CstKind { } Self::MatchCase { pattern, + condition, arrow, body, } => { pattern.fmt(f)?; arrow.fmt(f)?; + if let Some(box (comma, condition)) = condition { + comma.fmt(f)?; + condition.fmt(f)?; + } for expression in body { expression.fmt(f)?; } diff --git a/compiler/frontend/src/cst/tree_with_ids.rs b/compiler/frontend/src/cst/tree_with_ids.rs index 1edc219f0..878eab4a4 100644 --- a/compiler/frontend/src/cst/tree_with_ids.rs +++ b/compiler/frontend/src/cst/tree_with_ids.rs @@ -138,10 +138,16 @@ impl TreeWithIds for Cst { .or_else(|| cases.find(id)), CstKind::MatchCase { pattern, + condition, arrow, body, } => pattern .find(id) + .or_else(|| { + condition.as_deref().and_then(|(comma, condition)| { + comma.find(id).or_else(|| condition.find(id)) + }) + }) .or_else(|| arrow.find(id)) .or_else(|| body.find(id)), CstKind::Function { @@ -328,11 +334,19 @@ impl TreeWithIds for Cst { ), CstKind::MatchCase { pattern, + condition, arrow, body, } => ( pattern .find_by_offset(offset) + .or_else(|| { + condition.as_deref().and_then(|(comma, condition)| { + comma + .find_by_offset(offset) + .or_else(|| condition.find_by_offset(offset)) + }) + }) .or_else(|| arrow.find_by_offset(offset)) .or_else(|| body.find_by_offset(offset)), false, diff --git a/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs b/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs index 374c7a497..89742ead8 100644 --- a/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs +++ b/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs @@ -144,10 +144,17 @@ impl UnwrapWhitespaceAndComment for Cst { }, CstKind::MatchCase { pattern, + condition, arrow, body, } => CstKind::MatchCase { pattern: pattern.unwrap_whitespace_and_comment(), + condition: condition.as_deref().map(|(comma, condition)| { + Box::new(( + comma.unwrap_whitespace_and_comment(), + condition.unwrap_whitespace_and_comment(), + )) + }), arrow: arrow.unwrap_whitespace_and_comment(), body: body.unwrap_whitespace_and_comment(), }, diff --git a/compiler/frontend/src/cst_to_ast.rs b/compiler/frontend/src/cst_to_ast.rs index 63cfd885f..cda6ef8f6 100644 --- a/compiler/frontend/src/cst_to_ast.rs +++ b/compiler/frontend/src/cst_to_ast.rs @@ -584,23 +584,36 @@ impl LoweringContext { } CstKind::MatchCase { pattern, - arrow: _, + condition, + arrow, body, } => { if lowering_type != LoweringType::Expression { return self.create_ast_for_invalid_expression_in_pattern(cst); }; - + let mut errors = vec![]; let pattern = self.lower_cst(pattern, LoweringType::Pattern); - // TODO: handle error in arrow + let condition = condition + .as_ref() + .map(|box (_, condition)| self.lower_cst(condition, LoweringType::Expression)); + + if let CstKind::Error { + unparsable_input: _, + error, + } = arrow.kind + { + errors.push(self.create_error(arrow, error)); + } let body = self.lower_csts(body); - self.create_ast( - cst.data.id, + self.create_errors_or_ast( + cst, + errors, MatchCase { pattern: Box::new(pattern), + condition: condition.map(Box::new), body, }, ) diff --git a/compiler/frontend/src/error.rs b/compiler/frontend/src/error.rs index 36ee54807..295156483 100644 --- a/compiler/frontend/src/error.rs +++ b/compiler/frontend/src/error.rs @@ -63,6 +63,7 @@ impl Display for CompilerErrorPayload { CstError::MatchMissesCases => "This match misses cases to match against.", CstError::MatchCaseMissesArrow => "This match case misses an arrow.", CstError::MatchCaseMissesBody => "This match case misses a body to run.", + CstError::MatchCaseMissesCondition => "This match case condition is empty.", CstError::OpeningParenthesisMissesExpression => { "Here's an opening parenthesis without an expression after it." } diff --git a/compiler/frontend/src/hir.rs b/compiler/frontend/src/hir.rs index a0c6e7b3b..f35a8a9de 100644 --- a/compiler/frontend/src/hir.rs +++ b/compiler/frontend/src/hir.rs @@ -45,7 +45,11 @@ fn containing_body_of(db: &dyn HirDb, id: Id) -> Arc { Expression::Match { cases, .. } => { let body = cases .into_iter() - .map(|(_, body)| body) + .flat_map( + |MatchCase { + condition, body, .. + }| condition.into_iter().chain([body]), + ) .find(|body| body.expressions.contains_key(&id)) .unwrap(); Arc::new(body) @@ -90,7 +94,13 @@ impl Expression { Self::PatternIdentifierReference(_) => {} Self::Match { expression, cases } => { ids.push(expression.clone()); - for (_, body) in cases { + for MatchCase { + condition, body, .. + } in cases + { + if let Some(condition) = condition { + condition.collect_all_ids(ids); + } body.collect_all_ids(ids); } } @@ -318,7 +328,7 @@ pub enum Expression { /// Each case consists of the pattern to match against, and the body /// which starts with [PatternIdentifierReference]s for all identifiers /// in the pattern. - cases: Vec<(Pattern, Body)>, + cases: Vec, }, Function(Function), Builtin(BuiltinFunction), @@ -371,6 +381,14 @@ impl ToRichIr for PatternIdentifierId { } } +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct MatchCase { + pub pattern: Pattern, + pub identifier_expressions: Body, + pub condition: Option, + pub body: Body, +} + #[derive(Clone, PartialEq, Eq, Debug)] pub enum Pattern { NewIdentifier(PatternIdentifierId), @@ -594,18 +612,31 @@ impl ToRichIr for Expression { Self::Match { expression, cases } => { expression.build_rich_ir(builder); builder.push(" %", None, EnumSet::empty()); - builder.push_children_custom_multiline(cases, |builder, (pattern, body)| { - pattern.build_rich_ir(builder); - builder.push(" ->", None, EnumSet::empty()); - builder.indent(); - builder.push_foldable(|builder| { - if !body.expressions.is_empty() { - builder.push_newline(); + builder.push_children_custom_multiline( + cases, + |builder, + MatchCase { + pattern, + condition, + body, + .. + }| { + pattern.build_rich_ir(builder); + if let Some(condition) = condition { + builder.push(", ", None, EnumSet::empty()); + condition.build_rich_ir(builder); } - body.build_rich_ir(builder); - }); - builder.dedent(); - }); + builder.push(" ->", None, EnumSet::empty()); + builder.indent(); + builder.push_foldable(|builder| { + if !body.expressions.is_empty() { + builder.push_newline(); + } + body.build_rich_ir(builder); + }); + builder.dedent(); + }, + ); } Self::Function(function) => { builder.push( @@ -771,7 +802,16 @@ impl Expression { Self::Destructure { .. } => None, Self::PatternIdentifierReference { .. } => None, // TODO: use binary search - Self::Match { cases, .. } => cases.iter().find_map(|(_, body)| body.find(id)), + Self::Match { cases, .. } => cases.iter().find_map( + |MatchCase { + condition, body, .. + }| { + condition + .as_ref() + .and_then(|condition| condition.find(id)) + .or_else(|| body.find(id)) + }, + ), Self::Function(Function { body, .. }) => body.find(id), Self::Builtin(_) => None, Self::Call { .. } => None, @@ -811,8 +851,17 @@ impl CollectErrors for Expression { | Self::Struct(_) | Self::PatternIdentifierReference { .. } => {} Self::Match { cases, .. } => { - for (pattern, body) in cases { + for MatchCase { + pattern, + identifier_expressions: _, + condition, + body, + } in cases + { pattern.collect_errors(errors); + if let Some(condition) = condition { + condition.collect_errors(errors); + } body.collect_errors(errors); } } @@ -867,3 +916,10 @@ impl CollectErrors for Body { } } } +impl CollectErrors for Option { + fn collect_errors(&self, errors: &mut Vec) { + if let Some(this) = self { + this.collect_errors(errors); + } + } +} diff --git a/compiler/frontend/src/hir_to_mir.rs b/compiler/frontend/src/hir_to_mir.rs index 6bce9346b..3d37b4ecd 100644 --- a/compiler/frontend/src/hir_to_mir.rs +++ b/compiler/frontend/src/hir_to_mir.rs @@ -113,30 +113,12 @@ fn generate_needs_function(body: &mut BodyBuilder) -> Id { // Common stuff. let needs_code = body.push_hir_id(needs_id.clone()); let builtin_equals = body.push_builtin(BuiltinFunction::Equals); - let nothing_tag = body.push_nothing(); // Make sure the condition is a bool. - let true_tag = body.push_bool(true); - let false_tag = body.push_bool(false); - let is_condition_true = - body.push_call(builtin_equals, vec![condition, true_tag], needs_code); - let is_condition_bool = body.push_if_else( - &needs_id.child("isConditionTrue"), - is_condition_true, - |body| { - body.push_reference(true_tag); - }, - |body| { - body.push_call(builtin_equals, vec![condition, false_tag], needs_code); - }, - needs_code, - ); - body.push_if_else( + let is_condition_bool = body.push_is_bool(&needs_id, condition, needs_code); + body.push_if_not( &needs_id.child("isConditionBool"), is_condition_bool, - |body| { - body.push_reference(nothing_tag); - }, |body| { let panic_reason = body.push_text("The `condition` must be either `True` or `False`.".to_string()); @@ -154,12 +136,9 @@ fn generate_needs_function(body: &mut BodyBuilder) -> Id { vec![type_of_reason, text_tag], responsible_for_call, ); - body.push_if_else( + body.push_if_not( &needs_id.child("isReasonText"), is_reason_text, - |body| { - body.push_reference(nothing_tag); - }, |body| { let panic_reason = body.push_text("The `reason` must be a text.".to_string()); body.push_panic(panic_reason, responsible_for_call); @@ -168,12 +147,9 @@ fn generate_needs_function(body: &mut BodyBuilder) -> Id { ); // The core logic of the needs. - body.push_if_else( + body.push_if_not( &needs_id.child("condition"), condition, - |body| { - body.push_reference(nothing_tag); - }, |body| { body.push_panic(reason, responsible_for_condition); }, @@ -357,14 +333,10 @@ impl<'a> LoweringContext<'a> { is_trivial: false, }); - let nothing = body.push_nothing(); let is_match = body.push_is_match(pattern_result, responsible); - body.push_if_else( + body.push_if_not( &hir_id.child("isMatch"), is_match, - |body| { - body.push_reference(nothing); - }, |body| { let list_get_function = body.push_builtin(BuiltinFunction::ListGet); let one = body.push_int(1); @@ -385,6 +357,8 @@ impl<'a> LoweringContext<'a> { self.ongoing_destructuring.clone().unwrap(); if is_trivial { + // something % + // foo -> ... body.push_reference(result) } else { let list_get = body.push_builtin(BuiltinFunction::ListGet); @@ -516,7 +490,7 @@ impl<'a> LoweringContext<'a> { hir_id: hir::Id, body: &mut BodyBuilder, expression: Id, - cases: &[(hir::Pattern, hir::Body)], + cases: &[hir::MatchCase], responsible_for_needs: Id, responsible_for_match: Id, ) -> Id { @@ -527,7 +501,6 @@ impl<'a> LoweringContext<'a> { cases, responsible_for_needs, responsible_for_match, - vec![], 0, ) } @@ -537,10 +510,9 @@ impl<'a> LoweringContext<'a> { hir_id: hir::Id, body: &mut BodyBuilder, expression: Id, - cases: &[(hir::Pattern, hir::Body)], + cases: &[hir::MatchCase], responsible_for_needs: Id, responsible_for_match: Id, - mut no_match_reasons: Vec, case_index: usize, ) -> Id { match cases { @@ -549,7 +521,12 @@ impl<'a> LoweringContext<'a> { // TODO: concat reasons body.push_panic(reason, responsible_for_match) } - [(case_pattern, case_body), rest @ ..] => { + [hir::MatchCase { + pattern: case_pattern, + identifier_expressions: case_identifiers, + condition: case_condition, + body: case_body, + }, rest @ ..] => { let pattern_result = PatternLoweringContext::compile_pattern( body, hir_id.clone(), @@ -557,28 +534,12 @@ impl<'a> LoweringContext<'a> { expression, case_pattern, ); + let builtin_if_else = body.push_builtin(BuiltinFunction::IfElse); - let is_match = body.push_is_match(pattern_result, responsible_for_match); - + let is_pattern_match = body.push_is_match(pattern_result, responsible_for_match); let case_id = hir_id.child(format!("case-{case_index}")); - let builtin_if_else = body.push_builtin(BuiltinFunction::IfElse); - let then_function = body.push_function(case_id.child("matched"), |body, _| { - self.ongoing_destructuring = Some(OngoingDestructuring { - result: pattern_result, - is_trivial: false, - }); - self.compile_expressions(body, responsible_for_needs, &case_body.expressions); - }); - let else_function = body.push_function(case_id.child("didNotMatch"), |body, _| { - let list_get_function = body.push_builtin(BuiltinFunction::ListGet); - let one = body.push_int(1); - let reason = body.push_call( - list_get_function, - vec![pattern_result, one], - responsible_for_match, - ); - no_match_reasons.push(reason); + let else_function = body.push_function(case_id.child("didNotMatch"), |body, _| { self.compile_match_rec( hir_id, body, @@ -586,18 +547,85 @@ impl<'a> LoweringContext<'a> { rest, responsible_for_needs, responsible_for_match, - no_match_reasons, case_index + 1, ); }); + + let then_function = body.push_function(case_id.child("patternMatch"), |body, _| { + self.ongoing_destructuring = Some(OngoingDestructuring { + result: pattern_result, + is_trivial: false, + }); + self.compile_expressions( + body, + responsible_for_needs, + &case_identifiers.expressions, + ); + + self.compile_match_case_body( + &case_id, + body, + case_condition, + case_body, + else_function, + responsible_for_needs, + responsible_for_match, + ); + }); + body.push_call( builtin_if_else, - vec![is_match, then_function, else_function], + vec![is_pattern_match, then_function, else_function], responsible_for_match, ) } } } + #[allow(clippy::too_many_arguments)] + fn compile_match_case_body( + &mut self, + case_id: &hir::Id, + body: &mut BodyBuilder, + case_condition: &Option, + case_body: &hir::Body, + else_function: Id, + responsible_for_needs: Id, + responsible_for_match: Id, + ) { + let builtin_if_else = body.push_builtin(BuiltinFunction::IfElse); + if let Some(condition) = case_condition { + self.compile_expressions(body, responsible_for_needs, &condition.expressions); + let condition_result = body.current_return_value(); + + let is_boolean = body.push_is_bool(case_id, condition_result, responsible_for_match); + body.push_if_not( + &case_id.child("conditionCheck"), + is_boolean, + |body| { + let reason_parts = [ + body.push_text("Match Condition expected boolean value, got `".to_string()), + body.push_to_debug_text(condition_result, responsible_for_match), + body.push_text("`".to_string()), + ]; + let reason = body.push_text_concatenate(&reason_parts, responsible_for_match); + body.push_panic(reason, responsible_for_match); + }, + responsible_for_match, + ); + + let then_function = body.push_function(case_id.child("conditionMatch"), |body, _| { + self.compile_expressions(body, responsible_for_needs, &case_body.expressions); + }); + + body.push_call( + builtin_if_else, + vec![condition_result, then_function, else_function], + responsible_for_needs, + ); + } else { + self.compile_expressions(body, responsible_for_needs, &case_body.expressions); + }; + } } struct PatternLoweringContext { @@ -791,8 +819,7 @@ impl PatternLoweringContext { struct_as_text, body.push_text("`.".to_string()), ]; - let reason_text = - self.push_text_concatenate(body, reason_parts); + let reason_text = body.push_text_concatenate(&reason_parts, self.responsible); self.push_no_match(body, reason_text); }, self.responsible, @@ -934,7 +961,7 @@ impl PatternLoweringContext { body.push_call(to_debug_text, vec![expected], self.responsible); let actual_as_text = body.push_call(to_debug_text, vec![actual], self.responsible); let reason_parts = reason_factory(body, expected_as_text, actual_as_text); - let reason = self.push_text_concatenate(body, reason_parts); + let reason = body.push_text_concatenate(&reason_parts, self.responsible); self.push_no_match(body, reason); }, self.responsible, @@ -1010,22 +1037,6 @@ impl PatternLoweringContext { ) } - fn push_text_concatenate(&self, body: &mut BodyBuilder, parts: Vec) -> Id { - assert!(!parts.is_empty()); - - let builtin_text_concatenate = body.push_builtin(BuiltinFunction::TextConcatenate); - parts - .into_iter() - .reduce(|left, right| { - body.push_call( - builtin_text_concatenate, - vec![left, right], - self.responsible, - ) - }) - .unwrap() - } - fn push_match(&self, body: &mut BodyBuilder, mut captured_identifiers: Vec) -> Id { captured_identifiers.insert(0, self.match_tag); body.push_list(captured_identifiers) diff --git a/compiler/frontend/src/mir/body.rs b/compiler/frontend/src/mir/body.rs index c9956d675..39fa23655 100644 --- a/compiler/frontend/src/mir/body.rs +++ b/compiler/frontend/src/mir/body.rs @@ -411,7 +411,53 @@ impl BodyBuilder { responsible, ) } + pub fn push_if_not( + &mut self, + hir_id: &hir::Id, + condition: Id, + else_builder: E, + responsible: Id, + ) -> Id + where + E: FnOnce(&mut Self), + { + self.push_if_else( + hir_id, + condition, + |body| { + body.push_nothing(); + }, + else_builder, + responsible, + ) + } + pub fn push_is_bool(&mut self, hir_id: &hir::Id, value: Id, responsible: Id) -> Id { + let is_condition_true = self.push_equals_value(value, true, responsible); + self.push_if_else( + &hir_id.child("isValueTrue"), + is_condition_true, + |body| { + body.push_reference(is_condition_true); + }, + |body| { + body.push_equals_value(value, false, responsible); + }, + responsible, + ) + } + pub fn push_equals(&mut self, a: Id, b: Id, responsible: Id) -> Id { + let builtin_equals = self.push_builtin(BuiltinFunction::Equals); + self.push_call(builtin_equals, vec![a, b], responsible) + } + pub fn push_equals_value(&mut self, a: Id, b: impl Into, responsible: Id) -> Id { + let b = self.push(b.into()); + self.push_equals(a, b, responsible) + } + pub fn push_to_debug_text(&mut self, value: Id, responsible: Id) -> Id { + let builtin_to_debug_text = self.push_builtin(BuiltinFunction::ToDebugText); + self.push_call(builtin_to_debug_text, vec![value], responsible) + } pub fn push_panic(&mut self, reason: Id, responsible: Id) -> Id { self.push(Expression::Panic { reason, @@ -438,6 +484,19 @@ impl BodyBuilder { ) } + pub fn push_text_concatenate(&mut self, parts: &[Id], responsible: Id) -> Id { + assert!(!parts.is_empty()); + + let builtin_text_concatenate = self.push_builtin(BuiltinFunction::TextConcatenate); + parts + .iter() + .copied() + .reduce(|left, right| { + self.push_call(builtin_text_concatenate, vec![left, right], responsible) + }) + .unwrap() + } + #[must_use] pub fn current_return_value(&self) -> Id { self.body.return_value() diff --git a/compiler/frontend/src/rcst_to_cst.rs b/compiler/frontend/src/rcst_to_cst.rs index 8d493b614..86af35a32 100644 --- a/compiler/frontend/src/rcst_to_cst.rs +++ b/compiler/frontend/src/rcst_to_cst.rs @@ -257,10 +257,12 @@ impl Rcst { }, CstKind::MatchCase { pattern, + condition, arrow, body, } => CstKind::MatchCase { pattern: Box::new(pattern.to_cst(state)), + condition: condition.as_ref().map(|v| Box::new(v.to_cst(state))), arrow: Box::new(arrow.to_cst(state)), body: body.to_csts_helper(state), }, @@ -334,3 +336,10 @@ impl RcstsToCstsHelperExt for Vec { csts } } + +#[extension_trait] +impl ConvertToCst for (Rcst, Rcst) { + fn to_cst(&self, state: &mut State) -> (Cst, Cst) { + (self.0.to_cst(state), self.1.to_cst(state)) + } +} diff --git a/compiler/frontend/src/string_to_rcst/expression.rs b/compiler/frontend/src/string_to_rcst/expression.rs index ed29d3b0f..b85aa4625 100644 --- a/compiler/frontend/src/string_to_rcst/expression.rs +++ b/compiler/frontend/src/string_to_rcst/expression.rs @@ -5,7 +5,7 @@ use super::{ list::list, literal::{ arrow, bar, closing_bracket, closing_curly_brace, closing_parenthesis, colon_equals_sign, - dot, equals_sign, percent, + comma, dot, equals_sign, percent, }, struct_::struct_, text::text, @@ -405,6 +405,31 @@ fn match_case(input: &str, indentation: usize) -> Option<(&str, Rcst)> { let (input, whitespace) = whitespaces_and_newlines(input, indentation, true); let pattern = pattern.wrap_in_whitespace(whitespace); + let (input, condition) = if let Some((input, condition_comma)) = comma(input) { + let (input, whitespace) = whitespaces_and_newlines(input, indentation, true); + let condition_comma = condition_comma.wrap_in_whitespace(whitespace); + if let Some((input, condition_expresion)) = expression( + input, + indentation, + ExpressionParsingOptions { + allow_assignment: false, + allow_call: true, + allow_bar: true, + allow_function: true, + }, + ) { + (input, Some((condition_comma, condition_expresion))) + } else { + let error = CstKind::Error { + unparsable_input: String::new(), + error: CstError::MatchCaseMissesCondition, + }; + (input, Some((condition_comma, error.into()))) + } + } else { + (input, None) + }; + let (input, arrow) = if let Some((input, arrow)) = arrow(input) { let (input, whitespace) = whitespaces_and_newlines(input, indentation, true); (input, arrow.wrap_in_whitespace(whitespace)) @@ -429,6 +454,7 @@ fn match_case(input: &str, indentation: usize) -> Option<(&str, Rcst)> { let case = CstKind::MatchCase { pattern: Box::new(pattern), + condition: condition.map(Box::new), arrow: Box::new(arrow), body, }; @@ -606,6 +632,7 @@ mod test { ])), cases: vec![CstKind::MatchCase { pattern: Box::new(build_simple_int(123).with_trailing_space()), + condition: None, arrow: Box::new(CstKind::Arrow.with_trailing_space()), body: vec![build_simple_int(123)], } @@ -1010,6 +1037,7 @@ mod test { ])), cases: vec![CstKind::MatchCase { pattern: Box::new(build_simple_int(1).with_trailing_space()), + condition: None, arrow: Box::new(CstKind::Arrow.with_trailing_space()), body: vec![build_simple_int(2)], } diff --git a/compiler/language_server/src/features_candy/folding_ranges.rs b/compiler/language_server/src/features_candy/folding_ranges.rs index 1eceda8e1..468f3dd0e 100644 --- a/compiler/language_server/src/features_candy/folding_ranges.rs +++ b/compiler/language_server/src/features_candy/folding_ranges.rs @@ -137,11 +137,16 @@ where } CstKind::MatchCase { pattern, + condition, arrow, body, } => { self.visit_cst(pattern); + if let Some(box (_, condition)) = condition { + self.visit_cst(condition); + } + let arrow = arrow.unwrap_whitespace_and_comment(); let body_end = body .unwrap_whitespace_and_comment() diff --git a/compiler/language_server/src/features_candy/references.rs b/compiler/language_server/src/features_candy/references.rs index 965367af8..3e092a489 100644 --- a/compiler/language_server/src/features_candy/references.rs +++ b/compiler/language_server/src/features_candy/references.rs @@ -2,7 +2,7 @@ use crate::{features::Reference, utils::LspPositionConversion}; use candy_frontend::{ ast_to_hir::AstToHir, cst::{CstDb, CstKind}, - hir::{self, Body, Expression, Function, HirDb}, + hir::{self, Body, Expression, Function, HirDb, MatchCase}, module::{Module, ModuleDb}, position::{Offset, PositionConversionDb}, }; @@ -167,7 +167,10 @@ where | Expression::Destructure { .. } | Expression::PatternIdentifierReference (_) => {}, Expression::Match { cases, .. } => { - for (_, body) in cases { + for MatchCase{condition, body, ..} in cases { + if let Some(condition) = condition { + self.visit_body(condition); + } self.visit_body(body); } }, diff --git a/compiler/language_server/src/features_candy/semantic_tokens.rs b/compiler/language_server/src/features_candy/semantic_tokens.rs index 77a17cf3b..528f22397 100644 --- a/compiler/language_server/src/features_candy/semantic_tokens.rs +++ b/compiler/language_server/src/features_candy/semantic_tokens.rs @@ -236,10 +236,15 @@ fn visit_cst( } CstKind::MatchCase { pattern, + condition, arrow, body, } => { visit_cst(builder, pattern, None); + if let Some(box (comma, condition)) = condition { + visit_cst(builder, comma, None); + visit_cst(builder, condition, None); + } visit_cst(builder, arrow, None); visit_csts(builder, body, None); } diff --git a/packages/Examples/match.candy b/packages/Examples/match.candy index 5c79b08d2..a6af975be 100644 --- a/packages/Examples/match.candy +++ b/packages/Examples/match.candy @@ -1,10 +1,18 @@ [ifElse, int] = use "Core" -foo value = +buildEnum value = needs (int.is value) - ifElse (value | int.isLessThan 5) { Ok value } { Error "NOPE" } + ifElse (value | int.isLessThan 10) { Ok value } { Error "NOPE" } -main = foo 2 % - Ok value, value | int.isGreaterThan 5 -> value - Error value -> 10 - _ -> 20 +testFunction value = + needs (int.is value) + buildEnum value % + Ok value, value | int.isLessThan 2 -> value + Ok value, value | int.isGreaterThan 3 -> int.multiply value 2 + Error value -> 10 + _ -> 20 + + +main := { args -> + (testFunction 1, testFunction 2, testFunction 3, testFunction 4, testFunction 40) +}