From 15b8c4ff0c3963320037befd139f9155bd31d248 Mon Sep 17 00:00:00 2001 From: Rigidity Date: Tue, 23 Jul 2024 15:19:57 -0400 Subject: [PATCH] Temp 9 --- .../src/compiler/expr/function_call_expr.rs | 80 ++++++++++--------- crates/rue-typing/src/lib.rs | 2 +- crates/rue-typing/src/semantic_types.rs | 26 ++++++ 3 files changed, 71 insertions(+), 37 deletions(-) diff --git a/crates/rue-compiler/src/compiler/expr/function_call_expr.rs b/crates/rue-compiler/src/compiler/expr/function_call_expr.rs index e53eb4a..b39f06f 100644 --- a/crates/rue-compiler/src/compiler/expr/function_call_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/function_call_expr.rs @@ -2,13 +2,9 @@ use std::collections::HashMap; use rowan::TextRange; use rue_parser::{AstNode, FunctionCallExpr}; +use rue_typing::{deconstruct_items, unwrap_list, Callable, Rest, Semantics, Type, TypeId}; -use crate::{ - compiler::Compiler, - hir::Hir, - value::{FunctionType, Rest, Type, Value}, - ErrorKind, -}; +use crate::{compiler::Compiler, hir::Hir, value::Value, ErrorKind}; impl Compiler<'_> { pub fn compile_function_call_expr(&mut self, call: &FunctionCallExpr) -> Value { @@ -21,11 +17,16 @@ impl Compiler<'_> { let function_type = callee .as_ref() - .and_then(|callee| match self.db.ty(callee.type_id).clone() { - Type::Function(function_type) => Some(function_type), + .and_then(|callee| match self.ty.get(callee.type_id).clone() { + Type::Callable(function_type) => Some(function_type), _ => None, }); + let parameter_types = function_type.as_ref().map(|ty| { + deconstruct_items(self.ty, ty.parameters, ty.parameter_names.len(), ty.rest) + .expect("invalid function type") + }); + // Make sure the callee is callable, if present. if let Some(callee) = callee.as_ref() { if function_type.is_none() { @@ -51,16 +52,23 @@ impl Compiler<'_> { .unwrap_or(false); if let Some(function_type) = &function_type { - self.check_argument_length(function_type, len, call.syntax().text_range()); + self.check_argument_length( + function_type, + parameter_types.clone().unwrap(), + len, + call.syntax().text_range(), + ); } for (i, arg) in call_args.iter().enumerate() { // Determine the expected type. let expected_type = function_type.as_ref().and_then(|ty| { - if i < ty.param_types.len() { - Some(ty.param_types[i]) + let parameter_types = parameter_types.as_ref().unwrap(); + + if i < parameter_types.len() { + Some(parameter_types[i]) } else if ty.rest == Rest::Spread { - self.db.unwrap_list(*ty.param_types.last().unwrap()) + unwrap_list(self.ty, *parameter_types.last().unwrap()) } else { None } @@ -90,29 +98,31 @@ impl Compiler<'_> { continue; }; + let parameter_types = parameter_types.as_ref().unwrap(); + if last && spread { if function.rest != Rest::Spread { self.db.error( ErrorKind::UnsupportedFunctionSpread, call_args[i].syntax().text_range(), ); - } else if i >= function.param_types.len() - 1 { - let expected_type = *function.param_types.last().unwrap(); + } else if i >= parameter_types.len() - 1 { + let expected_type = *parameter_types.last().unwrap(); self.type_check(type_id, expected_type, call_args[i].syntax().text_range()); } - } else if function.rest == Rest::Spread && i >= function.param_types.len() - 1 { + } else if function.rest == Rest::Spread && i >= parameter_types.len() - 1 { if let Some(inner_list_type) = - self.db.unwrap_list(*function.param_types.last().unwrap()) + unwrap_list(self.ty, *parameter_types.last().unwrap()) { self.type_check(type_id, inner_list_type, call_args[i].syntax().text_range()); - } else if i == function.param_types.len() - 1 && !spread { + } else if i == parameter_types.len() - 1 && !spread { self.db.error( ErrorKind::RequiredFunctionSpread, call_args[i].syntax().text_range(), ); } - } else if i < function.param_types.len() { - let param_type = function.param_types[i]; + } else if i < parameter_types.len() { + let param_type = parameter_types[i]; self.type_check(type_id, param_type, call_args[i].syntax().text_range()); } } @@ -126,13 +136,15 @@ impl Compiler<'_> { function_type.map_or(self.ty.std().unknown, |expected| expected.return_type); if !generic_types.is_empty() { - type_id = self.db.substitute_type(type_id, &generic_types); + type_id = self + .ty + .substitute(type_id, generic_types, Semantics::Preserve); } // Build the HIR for the function call. let hir_id = self.db.alloc_hir(Hir::FunctionCall( - callee.map_or(self.builtins.unknown_hir, |callee| callee.hir_id), + callee.map_or(self.builtins.unknown, |callee| callee.hir_id), args.iter().map(|arg| arg.hir_id).collect(), spread, )); @@ -142,43 +154,39 @@ impl Compiler<'_> { fn check_argument_length( &mut self, - function: &FunctionType, + function: &Callable, + parameter_types: Vec, length: usize, text_range: TextRange, ) { match function.rest { Rest::Nil => { - if length != function.param_types.len() { + if length != parameter_types.len() { self.db.error( - ErrorKind::ArgumentMismatch(length, function.param_types.len()), + ErrorKind::ArgumentMismatch(length, parameter_types.len()), text_range, ); } } Rest::Optional => { - if length != function.param_types.len() && length != function.param_types.len() - 1 - { + if length != parameter_types.len() && length != parameter_types.len() - 1 { self.db.error( - ErrorKind::ArgumentMismatchOptional(length, function.param_types.len()), + ErrorKind::ArgumentMismatchOptional(length, parameter_types.len()), text_range, ); } } Rest::Spread => { - if self - .db - .unwrap_list(*function.param_types.last().unwrap()) - .is_some() - { - if length < function.param_types.len() - 1 { + if unwrap_list(self.ty, *parameter_types.last().unwrap()).is_some() { + if length < parameter_types.len() - 1 { self.db.error( - ErrorKind::ArgumentMismatchSpread(length, function.param_types.len()), + ErrorKind::ArgumentMismatchSpread(length, parameter_types.len()), text_range, ); } - } else if length != function.param_types.len() { + } else if length != parameter_types.len() { self.db.error( - ErrorKind::ArgumentMismatch(length, function.param_types.len()), + ErrorKind::ArgumentMismatch(length, parameter_types.len()), text_range, ); } diff --git a/crates/rue-typing/src/lib.rs b/crates/rue-typing/src/lib.rs index 3682658..413fb64 100644 --- a/crates/rue-typing/src/lib.rs +++ b/crates/rue-typing/src/lib.rs @@ -17,6 +17,7 @@ pub use check::*; pub use comparison::*; pub use semantic_types::*; pub use standard_types::*; +pub use substitute_type::*; pub use ty::*; pub use type_path::*; pub use type_system::*; @@ -24,7 +25,6 @@ pub use type_system::*; pub(crate) use difference::difference_type; pub(crate) use replace_type::replace_type; pub(crate) use stringify::stringify_type; -pub(crate) use substitute_type::{substitute_type, Semantics}; #[cfg(test)] mod test_tools; diff --git a/crates/rue-typing/src/semantic_types.rs b/crates/rue-typing/src/semantic_types.rs index ab9d5cd..3884aec 100644 --- a/crates/rue-typing/src/semantic_types.rs +++ b/crates/rue-typing/src/semantic_types.rs @@ -165,8 +165,26 @@ pub fn deconstruct_items( Some(items) } +/// Unwraps a list type into its inner type. +pub fn unwrap_list(db: &mut TypeSystem, type_id: TypeId) -> Option { + if db.compare(db.std().nil, type_id) > Comparison::Assignable { + return None; + } + + let non_nil = db.difference(type_id, db.std().nil); + let (first, rest) = db.get_pair(non_nil)?; + + if db.compare(rest, type_id) > Comparison::Assignable { + return None; + } + + Some(first) +} + #[cfg(test)] mod tests { + use crate::alloc_list; + use super::*; #[test] @@ -276,4 +294,12 @@ mod tests { let items = deconstruct_items(&mut db, type_id, 2, Rest::Optional); assert_eq!(items, Some(vec![std.bytes32, pair])); } + + #[test] + fn test_unwrap_list() { + let mut db = TypeSystem::new(); + let std = db.std(); + let list = alloc_list(&mut db, std.public_key); + assert_eq!(unwrap_list(&mut db, list), Some(std.public_key)); + } }