Skip to content

Commit

Permalink
Temp 9
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Jul 23, 2024
1 parent 564e397 commit 15b8c4f
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 37 deletions.
80 changes: 44 additions & 36 deletions crates/rue-compiler/src/compiler/expr/function_call_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@ use std::collections::HashMap;

use rowan::TextRange;
use rue_parser::{AstNode, FunctionCallExpr};
use rue_typing::{deconstruct_items, unwrap_list, Callable, Rest, Semantics, Type, TypeId};

use crate::{
compiler::Compiler,
hir::Hir,
value::{FunctionType, Rest, Type, Value},
ErrorKind,
};
use crate::{compiler::Compiler, hir::Hir, value::Value, ErrorKind};

impl Compiler<'_> {
pub fn compile_function_call_expr(&mut self, call: &FunctionCallExpr) -> Value {
Expand All @@ -21,11 +17,16 @@ impl Compiler<'_> {
let function_type =
callee
.as_ref()
.and_then(|callee| match self.db.ty(callee.type_id).clone() {
Type::Function(function_type) => Some(function_type),
.and_then(|callee| match self.ty.get(callee.type_id).clone() {
Type::Callable(function_type) => Some(function_type),
_ => None,
});

let parameter_types = function_type.as_ref().map(|ty| {
deconstruct_items(self.ty, ty.parameters, ty.parameter_names.len(), ty.rest)
.expect("invalid function type")
});

// Make sure the callee is callable, if present.
if let Some(callee) = callee.as_ref() {
if function_type.is_none() {
Expand All @@ -51,16 +52,23 @@ impl Compiler<'_> {
.unwrap_or(false);

if let Some(function_type) = &function_type {
self.check_argument_length(function_type, len, call.syntax().text_range());
self.check_argument_length(
function_type,
parameter_types.clone().unwrap(),
len,
call.syntax().text_range(),
);
}

for (i, arg) in call_args.iter().enumerate() {
// Determine the expected type.
let expected_type = function_type.as_ref().and_then(|ty| {
if i < ty.param_types.len() {
Some(ty.param_types[i])
let parameter_types = parameter_types.as_ref().unwrap();

if i < parameter_types.len() {
Some(parameter_types[i])
} else if ty.rest == Rest::Spread {
self.db.unwrap_list(*ty.param_types.last().unwrap())
unwrap_list(self.ty, *parameter_types.last().unwrap())
} else {
None
}
Expand Down Expand Up @@ -90,29 +98,31 @@ impl Compiler<'_> {
continue;
};

let parameter_types = parameter_types.as_ref().unwrap();

if last && spread {
if function.rest != Rest::Spread {
self.db.error(
ErrorKind::UnsupportedFunctionSpread,
call_args[i].syntax().text_range(),
);
} else if i >= function.param_types.len() - 1 {
let expected_type = *function.param_types.last().unwrap();
} else if i >= parameter_types.len() - 1 {
let expected_type = *parameter_types.last().unwrap();
self.type_check(type_id, expected_type, call_args[i].syntax().text_range());
}
} else if function.rest == Rest::Spread && i >= function.param_types.len() - 1 {
} else if function.rest == Rest::Spread && i >= parameter_types.len() - 1 {
if let Some(inner_list_type) =
self.db.unwrap_list(*function.param_types.last().unwrap())
unwrap_list(self.ty, *parameter_types.last().unwrap())
{
self.type_check(type_id, inner_list_type, call_args[i].syntax().text_range());
} else if i == function.param_types.len() - 1 && !spread {
} else if i == parameter_types.len() - 1 && !spread {
self.db.error(
ErrorKind::RequiredFunctionSpread,
call_args[i].syntax().text_range(),
);
}
} else if i < function.param_types.len() {
let param_type = function.param_types[i];
} else if i < parameter_types.len() {
let param_type = parameter_types[i];
self.type_check(type_id, param_type, call_args[i].syntax().text_range());
}
}
Expand All @@ -126,13 +136,15 @@ impl Compiler<'_> {
function_type.map_or(self.ty.std().unknown, |expected| expected.return_type);

if !generic_types.is_empty() {
type_id = self.db.substitute_type(type_id, &generic_types);
type_id = self
.ty
.substitute(type_id, generic_types, Semantics::Preserve);
}

// Build the HIR for the function call.

let hir_id = self.db.alloc_hir(Hir::FunctionCall(
callee.map_or(self.builtins.unknown_hir, |callee| callee.hir_id),
callee.map_or(self.builtins.unknown, |callee| callee.hir_id),
args.iter().map(|arg| arg.hir_id).collect(),
spread,
));
Expand All @@ -142,43 +154,39 @@ impl Compiler<'_> {

fn check_argument_length(
&mut self,
function: &FunctionType,
function: &Callable,
parameter_types: Vec<TypeId>,
length: usize,
text_range: TextRange,
) {
match function.rest {
Rest::Nil => {
if length != function.param_types.len() {
if length != parameter_types.len() {
self.db.error(
ErrorKind::ArgumentMismatch(length, function.param_types.len()),
ErrorKind::ArgumentMismatch(length, parameter_types.len()),
text_range,
);
}
}
Rest::Optional => {
if length != function.param_types.len() && length != function.param_types.len() - 1
{
if length != parameter_types.len() && length != parameter_types.len() - 1 {
self.db.error(
ErrorKind::ArgumentMismatchOptional(length, function.param_types.len()),
ErrorKind::ArgumentMismatchOptional(length, parameter_types.len()),
text_range,
);
}
}
Rest::Spread => {
if self
.db
.unwrap_list(*function.param_types.last().unwrap())
.is_some()
{
if length < function.param_types.len() - 1 {
if unwrap_list(self.ty, *parameter_types.last().unwrap()).is_some() {
if length < parameter_types.len() - 1 {
self.db.error(
ErrorKind::ArgumentMismatchSpread(length, function.param_types.len()),
ErrorKind::ArgumentMismatchSpread(length, parameter_types.len()),
text_range,
);
}
} else if length != function.param_types.len() {
} else if length != parameter_types.len() {
self.db.error(
ErrorKind::ArgumentMismatch(length, function.param_types.len()),
ErrorKind::ArgumentMismatch(length, parameter_types.len()),
text_range,
);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/rue-typing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ pub use check::*;
pub use comparison::*;
pub use semantic_types::*;
pub use standard_types::*;
pub use substitute_type::*;
pub use ty::*;
pub use type_path::*;
pub use type_system::*;

pub(crate) use difference::difference_type;
pub(crate) use replace_type::replace_type;
pub(crate) use stringify::stringify_type;
pub(crate) use substitute_type::{substitute_type, Semantics};

#[cfg(test)]
mod test_tools;
Expand Down
26 changes: 26 additions & 0 deletions crates/rue-typing/src/semantic_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,26 @@ pub fn deconstruct_items(
Some(items)
}

/// Unwraps a list type into its inner type.
pub fn unwrap_list(db: &mut TypeSystem, type_id: TypeId) -> Option<TypeId> {
if db.compare(db.std().nil, type_id) > Comparison::Assignable {
return None;
}

let non_nil = db.difference(type_id, db.std().nil);
let (first, rest) = db.get_pair(non_nil)?;

if db.compare(rest, type_id) > Comparison::Assignable {
return None;
}

Some(first)
}

#[cfg(test)]
mod tests {
use crate::alloc_list;

use super::*;

#[test]
Expand Down Expand Up @@ -276,4 +294,12 @@ mod tests {
let items = deconstruct_items(&mut db, type_id, 2, Rest::Optional);
assert_eq!(items, Some(vec![std.bytes32, pair]));
}

#[test]
fn test_unwrap_list() {
let mut db = TypeSystem::new();
let std = db.std();
let list = alloc_list(&mut db, std.public_key);
assert_eq!(unwrap_list(&mut db, list), Some(std.public_key));
}
}

0 comments on commit 15b8c4f

Please sign in to comment.