diff --git a/compiler_v4/src/ast_to_hir.rs b/compiler_v4/src/ast_to_hir.rs index c422052db..be765bf64 100644 --- a/compiler_v4/src/ast_to_hir.rs +++ b/compiler_v4/src/ast_to_hir.rs @@ -8,9 +8,9 @@ use crate::{ error::CompilerError, hir::{ self, Assignment, Body, BodyOrBuiltin, BuiltinFunction, ContainsError, Expression, - ExpressionKind, Function, FunctionSignature, Hir, Id, Impl, NamedType, Parameter, - ParameterType, SliceOfTypeParameter, StructField, SwitchCase, Trait, TraitDefinition, - TraitFunction, Type, TypeDeclaration, TypeDeclarationKind, TypeParameter, + ExpressionKind, Function, FunctionSignature, FunctionType, Hir, Id, Impl, NamedType, + Parameter, ParameterType, SliceOfTypeParameter, StructField, SwitchCase, Trait, + TraitDefinition, TraitFunction, Type, TypeDeclaration, TypeDeclarationKind, TypeParameter, }, id::IdGenerator, position::Offset, @@ -1023,7 +1023,7 @@ impl<'a> Context<'a> { } fn lower_assignment(&mut self, id: Id) { let declaration = self.assignments.get(&id).unwrap(); - let value = declaration.ast.value.clone(); + let value = &declaration.ast.value; let type_ = declaration.type_.clone(); let graph_index = declaration.graph_index; @@ -1232,7 +1232,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { context: &'c mut Context<'a>, type_parameters: &'c [TypeParameter], self_base_type: Option<&'c Type>, - fun: impl FnOnce(&mut BodyBuilder), + fun: impl FnOnce(&mut Self), ) -> (Body, FxHashSet) { let mut builder = Self { context, @@ -1246,8 +1246,12 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { (builder.body, builder.global_assignment_dependencies) } #[must_use] - fn build_inner(&mut self, fun: impl FnOnce(&mut BodyBuilder)) -> Body { - BodyBuilder::build( + fn build_inner<'s, 'cc>(&'s mut self, fun: impl FnOnce(&mut BodyBuilder<'cc, 'a>)) -> Body + where + 'c: 'cc, + 's: 'cc, + { + BodyBuilder::<'cc, 'a>::build( self.context, self.type_parameters, self.self_base_type, @@ -1263,7 +1267,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { fn lower_statements( &mut self, - statements: &[AstStatement], + statements: &'a [AstStatement], context_type: Option<&Type>, ) -> (Id, Type) { let mut last_expression = None; @@ -1314,7 +1318,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { fn lower_expression( &mut self, - expression: &AstExpression, + expression: &'a AstExpression, context_type: Option<&Type>, ) -> (Id, Type) { match self.lower_expression_raw(expression, context_type) { @@ -1342,7 +1346,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { } fn lower_expression_raw( &mut self, - expression: &AstExpression, + expression: &'a AstExpression, context_type: Option<&Type>, ) -> LoweredExpression { match &expression.kind { @@ -1427,7 +1431,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { match receiver { LoweredExpression::Expression { id, type_ } => { // bar.foo(baz) - let arguments = Self::lower_arguments(self, &call.arguments); + let arguments = self.lower_arguments(&call.arguments); let arguments = iter::once((id, type_)) .chain(arguments.into_vec()) .collect_vec(); @@ -1470,7 +1474,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { if identifier.string.chars().next().unwrap().is_lowercase() { // foo(bar, baz) - let arguments = Self::lower_arguments(self, &call.arguments); + let arguments = self.lower_arguments(&call.arguments); return self.lower_call( identifier, type_arguments.as_deref(), @@ -1568,6 +1572,15 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { LoweredExpression::Error } Type::Self_ { .. } => todo!(), + Type::Function(type_) => { + self.context.add_error( + key.span.clone(), + format!( + "Navigation on value of function type `{type_}` is not supported yet." + ), + ); + LoweredExpression::Error + } Type::Error => todo!(), }, LoweredExpression::NamedTypeReference(type_) => { @@ -1618,8 +1631,36 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { LoweredExpression::Error => LoweredExpression::Error, } } - AstExpressionKind::Lambda(AstLambda { .. }) => { - todo!() + AstExpressionKind::Lambda(AstLambda { parameters, body }) => { + let parameters = self.context.lower_parameters( + self.type_parameters, + self.self_base_type, + parameters, + ); + + let body = self.build_inner(|builder| { + for parameter in parameters.iter() { + builder.push_parameter(parameter.clone()); + } + + // TODO: pass context_type + builder.lower_statements( + &body.statements, + if let Some(Type::Function(FunctionType { + box return_type, .. + })) = context_type + { + Some(return_type) + } else { + None + }, + ); + }); + let type_ = FunctionType::new( + parameters.iter().map(|it| it.type_.clone()).collect_vec(), + body.return_type().clone(), + ); + self.push_lowered(None, ExpressionKind::Lambda { parameters, body }, type_) } AstExpressionKind::Body(AstBody { statements, .. }) => { let (id, type_) = self.lower_statements(statements, context_type); @@ -1664,11 +1705,21 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { Type::Parameter(type_) => { self.context.add_error( expression.span.clone(), - format!("Can't switch over type parameter `{}`", type_.name), + format!( + "Can't switch over value of type parameter type `{}`", + type_.name + ), ); return LoweredExpression::Error; } Type::Self_ { .. } => todo!(), + Type::Function(type_) => { + self.context.add_error( + expression.span.clone(), + format!("Can't switch over value of function type `{type_}`"), + ); + return LoweredExpression::Error; + } Type::Error => return LoweredExpression::Error, }; @@ -1795,7 +1846,75 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { type_arguments: Option<&[Type]>, arguments: &[(Id, Type)], ) -> LoweredExpression { - // TODO(lambdas): resolve local identifiers as well if not calling using instance syntax + let argument_types = arguments + .iter() + .map(|(_, type_)| type_.clone()) + .collect::>(); + + if let Some((_, id, type_)) = self + .local_identifiers + .iter() + .find(|(it, _, _)| it == &name.string) + { + // Local lambda call + let Type::Function(type_) = type_ else { + self.context.add_error( + name.span.clone(), + format!("`{}` is not a function", name.string), + ); + return LoweredExpression::Error; + }; + let id = *id; + let type_ = type_.clone(); + + let result = self.match_signature( + None, + &[], + &type_.parameter_types, + type_arguments, + &argument_types, + ); + return match result { + Ok(substitutions) => { + assert!(substitutions.is_empty()); + self.push_lowered( + name.string.clone(), + ExpressionKind::Call { + function: id, + substitutions, + arguments: arguments.iter().map(|(id, _)| *id).collect(), + }, + *type_.return_type, + ) + } + Err(error) => { + self.context.add_error( + name.span.clone(), + match error { + CallLikeLoweringError::TypeArgumentCount => { + "Wrong number of type arguments".to_string() + } + CallLikeLoweringError::ArgumentCount => { + "Wrong number of arguments".to_string() + } + CallLikeLoweringError::Unification(Some(error)) => error.into_string(), + CallLikeLoweringError::Unification(None) => { + "Mismatching types".to_string() + } + CallLikeLoweringError::FunctionReachableViaMultipleImpls => { + "Function is reachable via multiple impls".to_string() + } + // TODO: more specific error message + CallLikeLoweringError::TypeArgumentMismatch => { + "Type arguments are not assignable".to_string() + } + }, + ); + LoweredExpression::Error + } + }; + } + let matches = self .context .get_all_functions_matching_name(&name.string) @@ -1814,11 +1933,6 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { return LoweredExpression::Error; } - let argument_types = arguments - .iter() - .map(|(_, type_)| type_.clone()) - .collect::>(); - if argument_types.iter().any(ContainsError::contains_error) { return LoweredExpression::Error; } @@ -1913,7 +2027,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { fn lower_struct_creation( &mut self, span: Range, - call: &AstCall, + call: &'a AstCall, type_arguments: Option<&[Type]>, type_: &str, type_parameters: &[TypeParameter], @@ -1927,7 +2041,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { return LoweredExpression::Error; }; - let arguments = Self::lower_arguments(self, &call.arguments); + let arguments = self.lower_arguments(&call.arguments); let result = self.match_signature( None, @@ -1984,7 +2098,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { } fn lower_enum_creation( &mut self, - call: &AstCall, + call: &'a AstCall, type_arguments: Option<&[Type]>, type_: &str, variant: &AstString, @@ -2007,7 +2121,7 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { .as_ref() .map(|variant_type| vec![variant_type.clone()].into_boxed_slice()) .unwrap_or_default(); - let arguments = Self::lower_arguments(self, &call.arguments); + let arguments = self.lower_arguments(&call.arguments); let result = self.match_signature( None, @@ -2063,14 +2177,11 @@ impl<'c, 'a> BodyBuilder<'c, 'a> { enum_type, ) } - fn lower_arguments( - builder: &mut BodyBuilder, - arguments: &AstResult, - ) -> Box<[(Id, Type)]> { + fn lower_arguments(&mut self, arguments: &'a AstResult) -> Box<[(Id, Type)]> { arguments .arguments_or_default() .iter() - .map(|argument| builder.lower_expression(&argument.value, None)) + .map(|argument| self.lower_expression(&argument.value, None)) .collect::>() } fn match_signature( @@ -2333,6 +2444,23 @@ impl<'h> TypeUnifier<'h> { } (Type::Parameter { .. }, Type::Named { .. }) => Ok(true), (Type::Self_ { base_type }, _) => self.unify(base_type, parameter), + (Type::Function(argument), Type::Function(parameter)) => { + if argument.parameter_types.len() != parameter.parameter_types.len() { + return Ok(false); + } + for (argument, parameter) in argument + .parameter_types + .iter() + .zip_eq(parameter.parameter_types.iter()) + { + if !self.unify(argument, parameter)? { + return Ok(false); + } + } + + self.unify(&argument.return_type, ¶meter.return_type) + } + (Type::Function(_), _) | (_, Type::Function(_)) => Ok(false), (_, Type::Self_ { base_type: _ }) => { self.unify(argument, &ParameterType::self_type().into()) } diff --git a/compiler_v4/src/hir.rs b/compiler_v4/src/hir.rs index 51ea4095b..f5c7d34e6 100644 --- a/compiler_v4/src/hir.rs +++ b/compiler_v4/src/hir.rs @@ -316,6 +316,8 @@ pub enum Type { Self_ { base_type: Box, }, + #[from] + Function(FunctionType), Error, } #[derive(Clone, Debug, Eq, Hash, PartialEq)] @@ -394,6 +396,30 @@ impl Display for ParameterType { write!(f, "{}", self.name) } } +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct FunctionType { + pub parameter_types: Box<[Type]>, + pub return_type: Box, +} +impl FunctionType { + #[must_use] + pub fn new(parameter_types: impl Into>, return_type: impl Into) -> Self { + Self { + parameter_types: parameter_types.into(), + return_type: Box::new(return_type.into()), + } + } +} +impl Display for FunctionType { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "({}) {}", + self.parameter_types.iter().join(", "), + self.return_type, + ) + } +} impl Type { pub fn nothing() -> Self { NamedType::nothing().into() @@ -419,6 +445,10 @@ impl Type { }), Self::Parameter (type_) => environment.get(type_).unwrap_or_else(|| panic!("Missing substitution for type parameter {type_} (environment: {environment:?})")).clone(), Self::Self_ { base_type } => environment.get(&ParameterType::self_type()).cloned().unwrap_or_else(|| Self::Self_ { base_type: base_type.clone() }), + Self::Function(FunctionType{parameter_types, return_type }) => Self::Function(FunctionType::new( + parameter_types.iter().map(|it| it.substitute(environment)).collect_vec(), + return_type.substitute(environment), + )), Self::Error => Self::Error, } } @@ -449,6 +479,7 @@ impl Display for Type { Self::Named(type_) => write!(f, "{type_}"), Self::Parameter(ParameterType { name }) => write!(f, "{name}"), Self::Self_ { base_type } => write!(f, "Self<{base_type}>"), + Self::Function(type_) => write!(f, "{type_}"), Self::Error => write!(f, ""), } } @@ -472,6 +503,7 @@ impl ContainsError for Type { Self::Named(named_type) => named_type.contains_error(), Self::Parameter(_) => false, Self::Self_ { base_type } => base_type.contains_error(), + Self::Function(function_type) => function_type.contains_error(), Self::Error => true, } } @@ -483,6 +515,14 @@ impl ContainsError for NamedType { .any(ContainsError::contains_error) } } +impl ContainsError for FunctionType { + fn contains_error(&self) -> bool { + self.parameter_types + .iter() + .any(ContainsError::contains_error) + || self.return_type.contains_error() + } +} #[derive(Clone, Debug, Eq, PartialEq)] pub struct Assignment { @@ -549,6 +589,12 @@ pub struct Parameter { pub name: Box, pub type_: Type, } +impl ToText for Parameter { + fn build_text(&self, builder: &mut TextBuilder) { + self.id.build_text(builder); + builder.push(format!(": {} = {}", self.type_, self.name)); + } +} #[derive(Clone, Debug, Eq, PartialEq)] pub enum BodyOrBuiltin { @@ -565,6 +611,10 @@ impl Body { pub fn return_value_id(&self) -> Id { self.expressions.last().unwrap().0 } + #[must_use] + pub fn return_type(&self) -> &Type { + &self.expressions.last().unwrap().2.type_ + } } impl ToText for Body { fn build_text(&self, builder: &mut TextBuilder) { @@ -632,6 +682,10 @@ pub enum ExpressionKind { enum_: Type, cases: Box<[SwitchCase]>, }, + Lambda { + parameters: Box<[Parameter]>, + body: Body, + }, Error, } impl ToText for ExpressionKind { @@ -708,6 +762,12 @@ impl ToText for ExpressionKind { }); builder.push("}"); } + Self::Lambda { parameters, body } => { + builder.push("("); + builder.push_children(parameters.iter(), ", "); + builder.push(") => "); + body.build_text(builder); + } Self::Error => builder.push(""), } } diff --git a/compiler_v4/src/hir_to_mono.rs b/compiler_v4/src/hir_to_mono.rs index 85f4b8f10..723ab6c61 100644 --- a/compiler_v4/src/hir_to_mono.rs +++ b/compiler_v4/src/hir_to_mono.rs @@ -220,7 +220,10 @@ impl<'h> Context<'h> { .solve(&subgoal.substitute_all(&solver_substitutions), &[]); match solution { SolverSolution::Unique(solution) => Some(solution.used_rule), - SolverSolution::Ambiguous => panic!(), + SolverSolution::Ambiguous => panic!( + "Ambiguous solver solution for {}", + subgoal.substitute_all(&solver_substitutions), + ), SolverSolution::Impossible => None, } }) @@ -324,6 +327,19 @@ impl<'h> Context<'h> { hir::Type::Self_ { base_type } => { panic!("Self type (base type: {base_type}) should have been monomorphized.") } + hir::Type::Function(hir::FunctionType { + parameter_types, + return_type, + }) => { + entry.insert(None); + mono::TypeDeclaration::Function { + parameter_types: parameter_types + .iter() + .map(|it| self.lower_type(it)) + .collect(), + return_type: self.lower_type(return_type), + } + } hir::Type::Error => unreachable!(), }; *self.type_declarations.get_mut(&mangled_name).unwrap() = Some(declaration); @@ -353,6 +369,19 @@ impl<'h> Context<'h> { hir::Type::Self_ { base_type } => { panic!("Self type (base type: {base_type}) should have been monomorphized.") } + hir::Type::Function(type_) => { + result.push_str("$Fun$"); + if !type_.parameter_types.is_empty() { + result.push_str("of$"); + for type_ in type_.parameter_types.iter() { + Self::mangle_type_helper(result, type_); + result.push('$'); + } + result.push_str("end$"); + } + result.push_str("returns$"); + Self::mangle_type_helper(result, &type_.return_type); + } hir::Type::Error => result.push_str("Never"), } } @@ -384,7 +413,10 @@ impl<'c, 'h> BodyBuilder<'c, 'h> { (builder.parameters.into_boxed_slice(), builder.body) } #[must_use] - fn build_inner(&mut self, fun: impl FnOnce(&mut BodyBuilder)) -> mono::Body { + fn build_inner( + &mut self, + fun: impl FnOnce(&mut BodyBuilder), + ) -> (Box<[mono::Parameter]>, mono::Body) { let mut builder = BodyBuilder { context: self.context, environment: self.environment, @@ -397,10 +429,9 @@ impl<'c, 'h> BodyBuilder<'c, 'h> { builder.id_generator = mem::take(&mut self.id_generator); fun(&mut builder); - assert!(builder.parameters.is_empty()); self.id_generator = builder.id_generator; - builder.body + (builder.parameters.into_boxed_slice(), builder.body) } fn add_parameters(&mut self, parameters: &[hir::Parameter]) { @@ -501,19 +532,22 @@ impl<'c, 'h> BodyBuilder<'c, 'h> { substitutions, arguments, } => { - let function = self - .context - .lower_function(*function, &self.merge_substitutions(substitutions)); let arguments = self.lower_ids(arguments); - self.push( - id, - name, - mono::ExpressionKind::Call { + let expression_kind = if let Some(id) = self.id_mapping.get(function) { + let lambda = *self.id_mapping.get(function).unwrap_or_else(|| { + panic!("Unknown function: {function} (referenced from {id} ({name:?}))") + }); + mono::ExpressionKind::CallLambda { lambda, arguments } + } else { + let function = self + .context + .lower_function(*function, &self.merge_substitutions(substitutions)); + mono::ExpressionKind::CallFunction { function, arguments, - }, - &expression.type_, - ); + } + }; + self.push(id, name, expression_kind, &expression.type_); } hir::ExpressionKind::Switch { value, @@ -531,12 +565,14 @@ impl<'c, 'h> BodyBuilder<'c, 'h> { mono::SwitchCase { variant: case.variant.clone(), value_id: value_ids.map(|(_, mir_id)| mir_id), - body: BodyBuilder::build_inner(self, |builder| { - if let Some((hir_id, mir_id)) = value_ids { - builder.id_mapping.force_insert(hir_id, mir_id); - } - builder.lower_expressions(&case.body.expressions); - }), + body: self + .build_inner(|builder| { + if let Some((hir_id, mir_id)) = value_ids { + builder.id_mapping.force_insert(hir_id, mir_id); + } + builder.lower_expressions(&case.body.expressions); + }) + .1, } }) .collect(); @@ -551,6 +587,18 @@ impl<'c, 'h> BodyBuilder<'c, 'h> { &expression.type_, ); } + hir::ExpressionKind::Lambda { parameters, body } => { + let (parameters, body) = self.build_inner(|builder| { + builder.add_parameters(parameters); + builder.lower_expressions(&body.expressions); + }); + self.push( + id, + name, + mono::ExpressionKind::Lambda(mono::Lambda { parameters, body }), + &expression.type_, + ); + } hir::ExpressionKind::Error => todo!(), } } @@ -592,6 +640,14 @@ impl<'c, 'h> BodyBuilder<'c, 'h> { .collect(), }), hir::Type::Parameter(parameter_type) => self.environment[parameter_type].clone(), + hir::Type::Function(function_type) => hir::Type::Function(hir::FunctionType { + parameter_types: function_type + .parameter_types + .iter() + .map(|it| self.merge_substitution(it)) + .collect(), + return_type: Box::new(self.merge_substitution(&function_type.return_type)), + }), hir::Type::Self_ { .. } | hir::Type::Error => unreachable!(), } } diff --git a/compiler_v4/src/mono.rs b/compiler_v4/src/mono.rs index 7145faffb..dd2f493bc 100644 --- a/compiler_v4/src/mono.rs +++ b/compiler_v4/src/mono.rs @@ -1,6 +1,6 @@ use crate::{hir::BuiltinFunction, impl_countable_id}; use derive_more::Deref; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use std::fmt::{self, Display, Formatter}; #[derive(Clone, Copy, Debug, Default, Deref, Eq, Hash, Ord, PartialEq, PartialOrd)] @@ -33,6 +33,10 @@ pub enum TypeDeclaration { Enum { variants: Box<[EnumVariant]>, }, + Function { + parameter_types: Box<[Box]>, + return_type: Box, + }, } #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct EnumVariant { @@ -72,9 +76,72 @@ pub struct Body { pub expressions: Vec<(Id, Option>, Expression)>, } impl Body { + #[must_use] pub fn return_value_id(&self) -> Id { self.expressions.last().unwrap().0 } + #[must_use] + pub fn return_type(&self) -> &str { + &self.expressions.last().unwrap().2.type_ + } + + fn collect_defined_and_referenced_ids( + &self, + defined_ids: &mut FxHashSet, + referenced_ids: &mut FxHashSet, + ) { + for (id, _, expression) in &self.expressions { + defined_ids.insert(*id); + match &expression.kind { + ExpressionKind::Int(_) | ExpressionKind::Text(_) => {} + ExpressionKind::CreateStruct { fields, .. } => { + defined_ids.extend(fields.iter()); + } + ExpressionKind::StructAccess { struct_, .. } => { + referenced_ids.insert(*struct_); + } + ExpressionKind::CreateEnum { value, .. } => { + referenced_ids.extend(value.iter()); + } + ExpressionKind::GlobalAssignmentReference(_) => {} + ExpressionKind::LocalReference(referenced_id) => { + referenced_ids.insert(*referenced_id); + } + ExpressionKind::CallFunction { arguments, .. } => { + referenced_ids.extend(arguments.iter()); + } + ExpressionKind::CallLambda { + lambda, arguments, .. + } => { + referenced_ids.insert(*lambda); + referenced_ids.extend(arguments.iter()); + } + ExpressionKind::Switch { value, .. } => { + referenced_ids.insert(*value); + } + ExpressionKind::Lambda(Lambda { parameters, body }) => { + defined_ids.extend(parameters.iter().map(|it| it.id)); + body.collect_defined_and_referenced_ids(defined_ids, referenced_ids); + } + } + } + } + #[must_use] + pub fn find_expression(&self, id: Id) -> Option<&Expression> { + self.expressions.iter().find_map(|(it_id, _, expression)| { + if *it_id == id { + return Some(expression); + } + + match &expression.kind { + ExpressionKind::Switch { cases, .. } => { + cases.iter().find_map(|it| it.body.find_expression(id)) + } + ExpressionKind::Lambda(Lambda { body, .. }) => body.find_expression(id), + _ => None, + } + }) + } } #[derive(Clone, Debug, Eq, Hash, PartialEq)] @@ -102,19 +169,49 @@ pub enum ExpressionKind { }, GlobalAssignmentReference(Box), LocalReference(Id), - Call { + CallFunction { function: Box, arguments: Box<[Id]>, }, + CallLambda { + lambda: Id, + arguments: Box<[Id]>, + }, Switch { value: Id, enum_: Box, cases: Box<[SwitchCase]>, }, + Lambda(Lambda), } + #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct SwitchCase { pub variant: Box, pub value_id: Option, pub body: Body, } + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Lambda { + pub parameters: Box<[Parameter]>, + pub body: Body, +} +impl Lambda { + #[must_use] + pub fn closure_with_types(&self, function_body: &Body) -> FxHashMap> { + self.closure() + .into_iter() + .map(|id| (id, function_body.find_expression(id).unwrap().type_.clone())) + .collect() + } + #[must_use] + pub fn closure(&self) -> FxHashSet { + let mut defined_ids = self.parameters.iter().map(|it| it.id).collect(); + let mut referenced_ids = FxHashSet::default(); + self.body + .collect_defined_and_referenced_ids(&mut defined_ids, &mut referenced_ids); + referenced_ids.retain(|id| !defined_ids.contains(id)); + referenced_ids + } +} diff --git a/compiler_v4/src/mono_to_c.rs b/compiler_v4/src/mono_to_c.rs index 17aa46a5f..6355aa3b0 100644 --- a/compiler_v4/src/mono_to_c.rs +++ b/compiler_v4/src/mono_to_c.rs @@ -1,6 +1,9 @@ use crate::{ hir::BuiltinFunction, - mono::{Body, BodyOrBuiltin, Expression, ExpressionKind, Function, Id, Mono, TypeDeclaration}, + mono::{ + Body, BodyOrBuiltin, Expression, ExpressionKind, Function, Id, Lambda, Mono, + TypeDeclaration, + }, }; use itertools::Itertools; @@ -117,19 +120,34 @@ impl<'h> Context<'h> { } self.push("} value;\n};\n"); } + TypeDeclaration::Function { + parameter_types, + return_type, + } => { + self.push("void* closure;\n"); + self.push(format!("{return_type}* (*function)(void*")); + for parameter_type in parameter_types.iter() { + self.push(format!(", {parameter_type}*")); + } + self.push(");\n"); + self.push("};\n"); + } } } } fn lower_assignment_declarations(&mut self) { for (name, assignment) in &self.mono.assignments { + self.lower_lambda_declarations_in(name, &assignment.body); self.push(format!("{}* {name};\n", &assignment.type_)); } } fn lower_assignment_definitions(&mut self) { for (name, assignment) in &self.mono.assignments { + self.lower_lambda_definitions_in(name, &assignment.body); + self.push(format!("void {name}$init() {{\n")); - self.lower_body_expressions(&assignment.body); + self.lower_body_expressions(name, &assignment.body); self.push(format!( "{name} = {};\n}}\n\n", assignment.body.return_value_id(), @@ -139,31 +157,58 @@ impl<'h> Context<'h> { fn lower_function_declarations(&mut self) { for (name, function) in &self.mono.functions { + if let BodyOrBuiltin::Body(body) = &function.body { + self.lower_lambda_declarations_in(name, body); + } + self.lower_function_signature(name, function); self.push(";\n"); - - // self.lower_type(&Type::Function { - // parameter_types: function - // .parameters - // .iter() - // .map(|it| it.type_.clone()) - // .collect(), - // return_type: Box::new(function.return_type.clone()), - // }); - // self.push(format!( - // " {id} = {{ .closure = NULL, .function = {id}_function }};", - // )); - // self.push("\n"); } } + fn lower_lambda_declarations_in(&mut self, declaration_name: &str, body: &'h Body) { + Self::visit_lambdas_inside_body(body, &mut |id, lambda| { + self.push(format!( + "typedef struct {declaration_name}$lambda{id}_closure {declaration_name}$lambda{id}_closure;\n", + )); + + self.lower_lambda_signature(declaration_name, id, lambda); + self.push(";\n"); + }); + } fn lower_function_definitions(&mut self) { for (name, function) in &self.mono.functions { + if let BodyOrBuiltin::Body(body) = &function.body { + self.lower_lambda_definitions_in(name, body); + } + self.lower_function_signature(name, function); self.push(" {\n"); - self.lower_body_or_builtin(function); + self.lower_body_or_builtin(name, function); self.push("}\n\n"); } } + fn lower_lambda_definitions_in(&mut self, declaration_name: &str, body: &'h Body) { + Self::visit_lambdas_inside_body(body, &mut |id, lambda| { + let closure = lambda.closure_with_types(body); + + self.push(format!("struct {declaration_name}$lambda{id}_closure {{")); + for (id, type_) in &closure { + self.push(format!("{type_}* {id}; ")); + } + self.push("};\n"); + + self.lower_lambda_signature(declaration_name, id, lambda); + self.push(" {\n"); + self.push(format!( + "{declaration_name}$lambda{id}_closure* closure = raw_closure;\n" + )); + for (id, type_) in &closure { + self.push(format!("{type_}* {id} = closure->{id};\n")); + } + self.lower_body(declaration_name, &lambda.body); + self.push("}\n"); + }); + } fn lower_function_signature(&mut self, name: &str, function: &Function) { self.push(format!("{}* {name}(", &function.return_type)); for (index, parameter) in function.parameters.iter().enumerate() { @@ -174,7 +219,39 @@ impl<'h> Context<'h> { } self.push(")"); } - fn lower_body_or_builtin(&mut self, function: &Function) { + fn lower_lambda_signature(&mut self, declaration_name: &str, id: Id, lambda: &Lambda) { + self.push(format!( + "{}* {declaration_name}$lambda{id}_function(void* raw_closure", + &lambda.body.return_type() + )); + for parameter in lambda.parameters.iter() { + self.push(format!(", {}* {}", ¶meter.type_, parameter.id)); + } + self.push(")"); + } + + fn visit_lambdas_inside_body(body: &'h Body, visitor: &mut impl FnMut(Id, &'h Lambda)) { + for (id, _, expression) in &body.expressions { + match &expression.kind { + ExpressionKind::Int(_) + | ExpressionKind::Text(_) + | ExpressionKind::CreateStruct { .. } + | ExpressionKind::StructAccess { .. } + | ExpressionKind::CreateEnum { .. } + | ExpressionKind::GlobalAssignmentReference(_) + | ExpressionKind::LocalReference(_) + | ExpressionKind::CallFunction { .. } + | ExpressionKind::CallLambda { .. } + | ExpressionKind::Switch { .. } => {} + ExpressionKind::Lambda(lambda) => { + Self::visit_lambdas_inside_body(&lambda.body, visitor); + visitor(*id, lambda); + } + } + } + } + + fn lower_body_or_builtin(&mut self, declaration_name: &str, function: &Function) { match &function.body { BodyOrBuiltin::Builtin { builtin_function, @@ -444,24 +521,24 @@ impl<'h> Context<'h> { )), } } - BodyOrBuiltin::Body(body) => self.lower_body(body), + BodyOrBuiltin::Body(body) => self.lower_body(declaration_name, body), } } - fn lower_body(&mut self, body: &Body) { - self.lower_body_expressions(body); + fn lower_body(&mut self, declaration_name: &str, body: &Body) { + self.lower_body_expressions(declaration_name, body); self.push(format!("return {};", body.return_value_id())); } - fn lower_body_expressions(&mut self, body: &Body) { + fn lower_body_expressions(&mut self, declaration_name: &str, body: &Body) { for (id, name, expression) in &body.expressions { if let Some(name) = name { self.push(format!("// {name}\n")); } - self.lower_expression(*id, expression); - self.push("\n"); + self.lower_expression(declaration_name, *id, expression); + self.push("\n\n"); } } - fn lower_expression(&mut self, id: Id, expression: &Expression) { + fn lower_expression(&mut self, declaration_name: &str, id: Id, expression: &Expression) { match &expression.kind { ExpressionKind::Int(int) => { self.push(format!( @@ -511,20 +588,13 @@ impl<'h> Context<'h> { self.push(format!("\n{id}->value.{variant} = {value};")); } } - // ExpressionKind::Lambda(lambda) => { - // self.push("{ .closure = {"); - // for id in lambda.closure().iter().sorted() { - // self.push(format!(".{id} = {id}; ")); - // } - // self.push(format!("}}, .function = {id}_function }};")); - // } ExpressionKind::GlobalAssignmentReference(assignment) => { self.push(format!("{}* {id} = {assignment};", &expression.type_)); } ExpressionKind::LocalReference(referenced_id) => { self.push(format!("{}* {id} = {referenced_id};", &expression.type_)); } - ExpressionKind::Call { + ExpressionKind::CallFunction { function, arguments, } => { @@ -537,6 +607,16 @@ impl<'h> Context<'h> { } self.push(");"); } + ExpressionKind::CallLambda { lambda, arguments } => { + self.push(format!( + "{}* {id} = {lambda}->function({lambda}->closure", + &expression.type_ + )); + for argument in arguments.iter() { + self.push(format!(", {argument}")); + } + self.push(");"); + } ExpressionKind::Switch { value, enum_, @@ -565,7 +645,7 @@ impl<'h> Context<'h> { )); } - self.lower_body_expressions(&case.body); + self.lower_body_expressions(declaration_name, &case.body); self.push(format!("{id} = {};\n", case.body.return_value_id())); @@ -573,6 +653,22 @@ impl<'h> Context<'h> { } self.push("}"); } + ExpressionKind::Lambda(lambda) => { + self.push(format!("{declaration_name}$lambda{id}_closure* {id}_closure = malloc(sizeof({declaration_name}$lambda{id}_closure));\n",)); + for referenced_id in lambda.closure().iter().sorted() { + self.push(format!( + "{id}_closure->{referenced_id} = {referenced_id};\n" + )); + } + self.push(format!( + "{type_}* {id} = malloc(sizeof({type_}));", + type_ = &expression.type_, + )); + self.push(format!("{id}->closure = {id}_closure;")); + self.push(format!( + "{id}->function = {declaration_name}$lambda{id}_function;", + )); + } } } diff --git a/compiler_v4/src/type_solver/values.rs b/compiler_v4/src/type_solver/values.rs index 8618dfedf..d9901be34 100644 --- a/compiler_v4/src/type_solver/values.rs +++ b/compiler_v4/src/type_solver/values.rs @@ -147,7 +147,7 @@ impl TryFrom for SolverType { Type::Named(named_type) => SolverValue::try_from(named_type).map(SolverType::Value), Type::Parameter(parameter_type) => Ok(SolverVariable::from(parameter_type).into()), Type::Self_ { .. } => todo!(), - Type::Error => Err(()), + Type::Function(_) | Type::Error => Err(()), } } } diff --git a/packages_v5/example.candy b/packages_v5/example.candy index 6f8ae94e2..230b54d33 100644 --- a/packages_v5/example.candy +++ b/packages_v5/example.candy @@ -300,6 +300,10 @@ fun main() Int { let list = listOf(0, 1).insert(1, 2).replace(0, 3).removeAt(2) print("Length: {list.length().toText()}") print("[{list.get(0).toText()}, {list.get(1).toText()}, {list.get(2).toText()}]") + + let foo = 123 + let addCaptured = (x: Int) { x.add(foo) } + print("addCaptured(1) = {addCaptured(1).toText()}") 0 }