From e1a7f04881c22c777c2ba5e3ec0bf16f1d480844 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elizabeth=20Pa=C5=BA?= Date: Wed, 7 Feb 2024 16:00:19 +0100 Subject: [PATCH] Implement named constructor --- ante-ls/src/main.rs | 8 +- examples/nameresolution/named_constructor.an | 23 +++++ examples/typechecking/named_constructor.an | 13 +++ src/error/mod.rs | 17 +++- src/hir/monomorphisation.rs | 9 +- src/nameresolution/mod.rs | 100 ++++++++++++++++++- src/parser/ast.rs | 12 ++- src/parser/mod.rs | 29 +++--- src/parser/pretty_printer.rs | 11 +- src/types/typechecker.rs | 4 +- 10 files changed, 200 insertions(+), 26 deletions(-) create mode 100644 examples/nameresolution/named_constructor.an create mode 100644 examples/typechecking/named_constructor.an diff --git a/ante-ls/src/main.rs b/ante-ls/src/main.rs index 1bb96fc9..8fd78713 100644 --- a/ante-ls/src/main.rs +++ b/ante-ls/src/main.rs @@ -239,8 +239,12 @@ fn walk_ast<'a>(ast: &'a Ast<'a>, idx: usize) -> &'a Ast<'a> { } }, Ast::NamedConstructor(n) => { - if let Some((_, arg)) = n.args.iter().find(|(_, arg)| arg.locate().contains_index(&idx)) { - ast = arg; + let statements = match n.sequence.as_ref() { + Ast::Sequence(s) => &s.statements, + _ => unreachable!(), + }; + if let Some(stmt) = statements.iter().find(|stmt| stmt.locate().contains_index(&idx)) { + ast = stmt; } else if n.constructor.locate().contains_index(&idx) { ast = &n.constructor; } else { diff --git a/examples/nameresolution/named_constructor.an b/examples/nameresolution/named_constructor.an new file mode 100644 index 00000000..0b75a721 --- /dev/null +++ b/examples/nameresolution/named_constructor.an @@ -0,0 +1,23 @@ +type T a b = + x: a + y: b + +x = 4 + +t1 = T with x = 3, z = 5 +t2 = T with y = 3.2, x + +// Declarations should not leak from the named constructor +z = y * 2.0 + +// args: --check +// expected stderr: +// named_constructor.an:7:6 error: Missing fields: y +// t1 = T with x = 3, z = 5 +// +// named_constructor.an:7:20 error: z is not a struct field +// t1 = T with x = 3, z = 5 +// +// named_constructor.an:11:5 error: No declaration for `y` was found in scope +// z = y * 2.0 + diff --git a/examples/typechecking/named_constructor.an b/examples/typechecking/named_constructor.an new file mode 100644 index 00000000..1d0e0819 --- /dev/null +++ b/examples/typechecking/named_constructor.an @@ -0,0 +1,13 @@ +type Foo = + x: I32 + y: U64 + +x = "Hello World" + +foo = Foo with x = 1, y = 42 + +// args: --show-types +// expected stdout: +// Foo : (forall a. (I32 - U64 -> Foo can a)) +// foo : Foo +// x : String diff --git a/src/error/mod.rs b/src/error/mod.rs index a4e1aafa..a1b3efbe 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -60,6 +60,9 @@ pub enum DiagnosticKind { HandlerMissingCases(/*missing effect cases*/ Vec), ImportShadowsPreviousDefinition(/*item name*/ String), Unused(/*item name*/ String), + NotAStruct(/*struct name*/ String), + MissingFields(/*missing struct fields*/ Vec), + NotAStructField(/*field name*/ String), // // Type Checking @@ -206,6 +209,15 @@ impl Display for DiagnosticKind { DiagnosticKind::Unused(item) => { write!(f, "{item} is unused (prefix name with _ to silence this warning)") }, + DiagnosticKind::NotAStruct(name) => { + write!(f, "{} is not a struct", name) + }, + DiagnosticKind::NotAStructField(name) => { + write!(f, "{} is not a struct field", name) + }, + DiagnosticKind::MissingFields(fields) => { + write!(f, "Missing fields: {}", fields.join(", ")) + }, DiagnosticKind::TypeLengthMismatch(left, right) => { write!( f, @@ -366,7 +378,10 @@ impl DiagnosticKind { | MultipleMatchingImpls(_, _) | NoMatchingImpls(_) | MissingCase(_) - | InternalError(_) => Error, + | InternalError(_) + | NotAStruct(_) + | MissingFields(_) + | NotAStructField(_) => Error, } } } diff --git a/src/hir/monomorphisation.rs b/src/hir/monomorphisation.rs index 48cbe57b..0cebbd4a 100644 --- a/src/hir/monomorphisation.rs +++ b/src/hir/monomorphisation.rs @@ -125,7 +125,7 @@ impl<'c> Context<'c> { Assignment(assignment) => self.monomorphise_assignment(assignment), EffectDefinition(_) => todo!(), Handle(_) => todo!(), - NamedConstructor(_) => todo!(), + NamedConstructor(constructor) => self.monomorphise_named_constructor(constructor), } } @@ -1586,6 +1586,13 @@ impl<'c> Context<'c> { hir::Ast::Assignment(hir::Assignment { lhs: Box::new(lhs), rhs: Box::new(self.monomorphise(&assignment.rhs)) }) } + fn monomorphise_named_constructor(&mut self, constructor: &ast::NamedConstructor<'c>) -> hir::Ast { + match constructor.sequence.as_ref() { + ast::Ast::Sequence(sequence) => self.monomorphise_sequence(sequence), + _ => unreachable!(), + } + } + pub fn extract(ast: hir::Ast, member_index: u32, result_type: Type) -> hir::Ast { use hir::{ Ast, diff --git a/src/nameresolution/mod.rs b/src/nameresolution/mod.rs index 11f25b01..f21be407 100644 --- a/src/nameresolution/mod.rs +++ b/src/nameresolution/mod.rs @@ -1627,7 +1627,103 @@ impl<'c> Resolvable<'c> for ast::Handle<'c> { impl<'c> Resolvable<'c> for ast::NamedConstructor<'c> { fn declare(&mut self, _resolver: &mut NameResolver, _cache: &mut ModuleCache<'c>) {} - fn define(&mut self, _resolver: &mut NameResolver, _cache: &mut ModuleCache<'c>) { - todo!() + fn define(&mut self, resolver: &mut NameResolver, cache: &mut ModuleCache<'c>) { + let type_name = match self.constructor.as_ref() { + Ast::Variable(ast::Variable { kind, .. }) => kind.name(), + _ => { + // This should never happen since constructor is parsed with the `variant` parser + cache.push_diagnostic( + self.constructor.locate(), + D::InternalError("Expected consturctor field to be a Variable"), + ); + return; + }, + }; + // This will increment the use count for that type. + // It will result in it being one higher than it needs to, + // as the define pass on the sequence will do it again, + // since it ends with a FunctionCall. Is that a problem? + let type_info = match resolver.lookup_type(type_name.as_ref(), cache) { + Some(id) => &cache.type_infos[id.0], + None => { + cache.push_diagnostic(self.location, D::NotInScope("Type", type_name.as_ref().clone())); + return; + }, + }; + + // Field names in the order they appear in the type definition + let struct_fields = match &type_info.body { + TypeInfoBody::Struct(fields) => fields.iter().map(|field| &field.name), + _ => { + cache.push_diagnostic(self.constructor.locate(), D::NotAStruct(type_name.as_ref().clone())); + return; + }, + }; + let statements = match self.sequence.as_mut() { + Ast::Sequence(ast::Sequence { statements, .. }) => statements, + _ => unreachable!(), + }; + + // Fields referenced in the constructor + let defined_fields = statements + .iter() + .map(|stmt| { + let (variable, location) = match stmt { + Ast::Definition(ast::Definition { pattern, location, .. }) => (pattern.as_ref(), location), + Ast::Variable(v) => (stmt, &v.location), + _ => unreachable!(), + }; + + let name = match variable { + Ast::Variable(ast::Variable { kind: ast::VariableKind::Identifier(name), .. }) => name, + _ => unreachable!(), + }; + + (name, (variable, location)) + }) + .collect::>(); + + // Collecting missing and unknown fields, + // as well as the arguments to the constructor in a single pass + let (missing_fields, unknown_fields, args) = struct_fields.fold( + (Vec::new(), defined_fields.keys().cloned().collect::>(), Vec::new()), + |(mut mf, mut uf, mut args), field| { + if let Some(&(variable, _)) = defined_fields.get(field) { + args.push(variable.clone()); + uf.retain(|&f| f != field); + } else { + mf.push(field); + } + (mf, uf, args) + }, + ); + + let has_missing_fields = !missing_fields.is_empty(); + let has_unknown_fields = !unknown_fields.is_empty(); + + if has_missing_fields { + cache.push_diagnostic( + self.constructor.locate(), + D::MissingFields(missing_fields.into_iter().cloned().collect()), + ); + } + + for unknown_field in unknown_fields { + cache.push_diagnostic(*defined_fields[&unknown_field].1, D::NotAStructField(unknown_field.clone())); + } + + if has_missing_fields || has_unknown_fields { + return; + } + + let call = ast::Ast::function_call(self.constructor.as_ref().clone(), args, self.location); + + // We only want to keep declarations in the sequence to keep the Hir simpler + statements.retain(|s| !matches!(s, Ast::Variable(_))); + statements.push(call); + + resolver.push_scope(cache); + self.sequence.define(resolver, cache); + resolver.pop_scope(cache, false, None); } } diff --git a/src/parser/ast.rs b/src/parser/ast.rs index 40e502a3..23b51271 100644 --- a/src/parser/ast.rs +++ b/src/parser/ast.rs @@ -400,8 +400,7 @@ pub struct Handle<'a> { #[derive(Debug, Clone)] pub struct NamedConstructor<'a> { pub constructor: Box>, - pub args: Vec<(String, Ast<'a>)>, - + pub sequence: Box>, pub location: Location<'a>, pub typ: Option, } @@ -709,8 +708,13 @@ impl<'a> Ast<'a> { Ast::Handle(Handle { expression: Box::new(expression), branches, location, resumes: vec![], typ: None }) } - pub fn named_constructor(constructor: Ast<'a>, args: Vec<(String, Ast<'a>)>, location: Location<'a>) -> Ast<'a> { - Ast::NamedConstructor(NamedConstructor { constructor: Box::new(constructor), args, location, typ: None }) + pub fn named_constructor(constructor: Ast<'a>, sequence: Ast<'a>, location: Location<'a>) -> Ast<'a> { + Ast::NamedConstructor(NamedConstructor { + constructor: Box::new(constructor), + sequence: Box::new(sequence), + location, + typ: None, + }) } /// This is a bit of a hack. diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 6be043de..3fd28112 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -458,11 +458,11 @@ parser!(function_call loc = parser!(named_constructor_expr loc = constructor <- variant; _ <- expect(Token::With); - args !<- named_constructor_args; - Ast::named_constructor(constructor, args, loc) + sequence !<- named_constructor_args; + Ast::named_constructor(constructor, sequence, loc) ); -fn named_constructor_args<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Vec<(String, Ast<'b>)>> { +fn named_constructor_args<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Ast<'b>> { if let Token::Indent = input[0].0 { named_constructor_block_args(input) } else { @@ -470,24 +470,27 @@ fn named_constructor_args<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, V } } -parser!(named_constructor_block_args loc -> 'b Vec<(String, Ast<'b>)> = +parser!(named_constructor_block_args loc -> 'b Ast<'b> = _ <- expect(Token::Indent); - args <- delimited_trailing(named_constructor_arg, expect(Token::Newline), false); + statements <- delimited_trailing(named_constructor_arg, expect(Token::Newline), false); _ !<- expect(Token::Unindent); - args + Ast::sequence(statements, loc) ); -parser!(named_constructor_inline_args loc -> 'b Vec<(String, Ast<'b>)> = - args <- delimited(named_constructor_arg, expect(Token::Comma)); - args +parser!(named_constructor_inline_args loc -> 'b Ast<'b> = + statements <- delimited(named_constructor_arg, expect(Token::Comma)); + Ast::sequence(statements, loc) ); -fn named_constructor_arg<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, (String, Ast<'b>)> { +fn named_constructor_arg<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Ast<'b>> { let (input, ident, start) = identifier(input)?; + let field_name = Ast::variable(vec![], ident, start); let (input, maybe_expr, end) = maybe(pair(expect(Token::Equal), function_argument))(input)?; - // Desugar bar, baz into bar = bar, baz = baz - let expr = maybe_expr.map_or_else(|| Ast::variable(vec![], ident.clone(), start), |(_, expr)| expr); - Ok((input, (ident, expr), start.union(end))) + let expr = match maybe_expr { + Some((_, expr)) => Ast::definition(field_name, expr, start.union(end)), + None => field_name, + }; + Ok((input, expr, start.union(end))) } parser!(pattern_function_call loc = diff --git a/src/parser/pretty_printer.rs b/src/parser/pretty_printer.rs index 2364d1dd..097f7c54 100644 --- a/src/parser/pretty_printer.rs +++ b/src/parser/pretty_printer.rs @@ -275,7 +275,16 @@ impl<'a> Display for ast::Handle<'a> { impl<'a> Display for ast::NamedConstructor<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let args = fmap(self.args.iter(), |(name, expr)| format!("{name} = {expr}")); + let statements = match self.sequence.as_ref() { + Ast::Sequence(ast::Sequence { statements, .. }) => statements, + _ => unreachable!(), + }; + let args = fmap(statements, |stmt| match stmt { + Ast::Definition(ast::Definition { pattern, expr, .. }) => format!("{pattern} = {expr}"), + Ast::Variable(v) => format!("{v} = {v}"), + _ => unreachable!(), + }); + write!(f, "({} with {})", self.constructor, args.join(", ")) } } diff --git a/src/types/typechecker.rs b/src/types/typechecker.rs index 28f82df2..04e8fa7c 100644 --- a/src/types/typechecker.rs +++ b/src/types/typechecker.rs @@ -2060,7 +2060,7 @@ impl<'a> Inferable<'a> for ast::Handle<'a> { } impl<'a> Inferable<'a> for ast::NamedConstructor<'a> { - fn infer_impl(&mut self, _checker: &mut ModuleCache<'a>) -> TypeResult { - todo!() + fn infer_impl(&mut self, cache: &mut ModuleCache<'a>) -> TypeResult { + self.sequence.infer_impl(cache) } }