Skip to content

Commit

Permalink
Merge pull request #29 from Rigidity/type-guard-redesign
Browse files Browse the repository at this point in the history
Type guard redesign
  • Loading branch information
Rigidity authored Aug 4, 2024
2 parents 7e6ad94 + a4bbcf2 commit d01a2f5
Show file tree
Hide file tree
Showing 14 changed files with 140 additions and 69 deletions.
9 changes: 3 additions & 6 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- name: Cargo binstall
run: curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash

- name: Instal cargo-workspaces
run: cargo binstall cargo-workspaces --locked -y
- name: Install cargo-workspaces
run: cargo install cargo-workspaces

- name: Run tests
run: cargo test --all-features --workspace
Expand All @@ -31,7 +28,7 @@ jobs:

- name: Unused dependencies
run: |
cargo binstall cargo-machete --locked -y
cargo install cargo-machete --locked
cargo machete
- name: Fmt
Expand Down
54 changes: 46 additions & 8 deletions crates/rue-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{
database::{Database, HirId, ScopeId, SymbolId},
hir::{Hir, Op},
scope::Scope,
symbol::{Function, Symbol},
value::{GuardPath, Value},
ErrorKind,
};
Expand Down Expand Up @@ -46,8 +47,8 @@ pub struct Compiler<'a> {
// The type definition stack is used for calculating types referenced in types.
type_definition_stack: Vec<TypeId>,

// The type guard stack is used for overriding types in certain contexts.
type_guard_stack: Vec<HashMap<GuardPath, TypeId>>,
// Overridden symbol types due to type guards.
type_overrides: Vec<HashMap<SymbolId, TypeId>>,

// The generic type stack is used for overriding generic types that are being checked against.
generic_type_stack: Vec<HashMap<TypeId, TypeId>>,
Expand All @@ -74,7 +75,7 @@ impl<'a> Compiler<'a> {
scope_stack: vec![builtins.scope_id],
symbol_stack: Vec::new(),
type_definition_stack: Vec::new(),
type_guard_stack: Vec::new(),
type_overrides: Vec::new(),
generic_type_stack: Vec::new(),
allow_generic_inference_stack: vec![false],
is_callee: false,
Expand Down Expand Up @@ -169,13 +170,50 @@ impl<'a> Compiler<'a> {
Value::new(self.builtins.unknown, self.ty.std().unknown)
}

fn symbol_type(&self, guard_path: &GuardPath) -> Option<TypeId> {
for guards in self.type_guard_stack.iter().rev() {
if let Some(guard) = guards.get(guard_path) {
return Some(*guard);
fn build_overrides(&mut self, guards: HashMap<GuardPath, TypeId>) -> HashMap<SymbolId, TypeId> {
type GuardItem = (Vec<TypePath>, TypeId);

let mut symbol_guards: HashMap<SymbolId, Vec<GuardItem>> = HashMap::new();

for (guard_path, type_id) in guards {
symbol_guards
.entry(guard_path.symbol_id)
.or_default()
.push((guard_path.items, type_id));
}

let mut overrides = HashMap::new();

for (symbol_id, mut items) in symbol_guards {
// Order by length.
items.sort_by_key(|(items, _)| items.len());

let mut type_id = self.symbol_type(symbol_id);

for (path_items, new_type_id) in items {
type_id = self.ty.replace(type_id, new_type_id, &path_items);
}

overrides.insert(symbol_id, type_id);
}

overrides
}

fn symbol_type(&self, symbol_id: SymbolId) -> TypeId {
for guards in self.type_overrides.iter().rev() {
if let Some(type_id) = guards.get(&symbol_id) {
return *type_id;
}
}
None

match self.db.symbol(symbol_id) {
Symbol::Unknown | Symbol::Module(..) => unreachable!(),
Symbol::Function(Function { type_id, .. })
| Symbol::InlineFunction(Function { type_id, .. })
| Symbol::Parameter(type_id) => *type_id,
Symbol::Let(value) | Symbol::Const(value) | Symbol::InlineConst(value) => value.type_id,
}
}

fn scope(&self) -> &Scope {
Expand Down
14 changes: 8 additions & 6 deletions crates/rue-compiler/src/compiler/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ impl Compiler<'_> {

// Push the type guards onto the stack.
// This will be popped in reverse order later after all statements have been lowered.
self.type_guard_stack.push(else_guards);
let overrides = self.build_overrides(else_guards);
self.type_overrides.push(overrides);

statements.push(Statement::If(condition_hir, then_hir));
}
Expand Down Expand Up @@ -103,8 +104,8 @@ impl Compiler<'_> {
// If the condition is false, we raise an error.
// So we can assume that the condition is true from this point on.
// This will be popped in reverse order later after all statements have been lowered.

self.type_guard_stack.push(condition.then_guards());
let overrides = self.build_overrides(condition.then_guards());
self.type_overrides.push(overrides);

let not_condition = self.db.alloc_hir(Hir::Op(Op::Not, condition.hir_id));
let raise = self.db.alloc_hir(Hir::Raise(None));
Expand All @@ -126,7 +127,8 @@ impl Compiler<'_> {
assume_stmt.syntax().text_range(),
);

self.type_guard_stack.push(expr.then_guards());
let overrides = self.build_overrides(expr.then_guards());
self.type_overrides.push(overrides);
statements.push(Statement::Assume);
}
}
Expand Down Expand Up @@ -158,7 +160,7 @@ impl Compiler<'_> {
body = value;
}
Statement::If(condition, then_block) => {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

body = Value::new(
self.db
Expand All @@ -167,7 +169,7 @@ impl Compiler<'_> {
);
}
Statement::Assume => {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();
}
}
}
Expand Down
26 changes: 18 additions & 8 deletions crates/rue-compiler/src/compiler/expr/binary_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl Compiler<'_> {
let else_type = self.ty.difference(rhs.type_id, self.ty.std().nil);
value
.guards
.insert(guard_path, Guard::new(then_type, else_type));
.insert(guard_path, Guard::new(Some(then_type), Some(else_type)));
}
}

Expand All @@ -182,7 +182,7 @@ impl Compiler<'_> {
let else_type = self.ty.difference(lhs.type_id, self.ty.std().nil);
value
.guards
.insert(guard_path, Guard::new(then_type, else_type));
.insert(guard_path, Guard::new(Some(then_type), Some(else_type)));
}
}

Expand Down Expand Up @@ -250,13 +250,14 @@ impl Compiler<'_> {
}

fn op_and(&mut self, lhs: Value, rhs: Option<&Expr>, text_range: TextRange) -> Value {
self.type_guard_stack.push(lhs.then_guards());
let overrides = self.build_overrides(lhs.then_guards());
self.type_overrides.push(overrides);

let rhs = rhs
.map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool)))
.unwrap_or_else(|| self.unknown());

self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

self.type_check(lhs.type_id, self.ty.std().bool, text_range);
self.type_check(rhs.type_id, self.ty.std().bool, text_range);
Expand All @@ -267,19 +268,28 @@ impl Compiler<'_> {
rhs.hir_id,
self.ty.std().bool,
);
value.guards.extend(lhs.guards);
value.guards.extend(rhs.guards);
value.guards.extend(
lhs.guards
.into_iter()
.map(|(path, guard)| (path, Guard::new(guard.then_type, None))),
);
value.guards.extend(
rhs.guards
.into_iter()
.map(|(path, guard)| (path, Guard::new(guard.then_type, None))),
);
value
}

fn op_or(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value {
self.type_guard_stack.push(lhs.then_guards());
let overrides = self.build_overrides(lhs.else_guards());
self.type_overrides.push(overrides);

let rhs = rhs
.map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool)))
.unwrap_or_else(|| self.unknown());

self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

self.type_check(lhs.type_id, self.ty.std().bool, text_range);
self.type_check(rhs.type_id, self.ty.std().bool, text_range);
Expand Down
18 changes: 5 additions & 13 deletions crates/rue-compiler/src/compiler/expr/field_access_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ impl Compiler<'_> {
return self.unknown();
};

let mut new_value = match self.ty.get(old_value.type_id).clone() {
Type::Unknown => return self.unknown(),
match self.ty.get(old_value.type_id).clone() {
Type::Unknown => self.unknown(),
Type::Struct(ty) => {
let Some(value) = self.compile_struct_field_access(old_value, &ty, &name) else {
return self.unknown();
Expand Down Expand Up @@ -55,17 +55,9 @@ impl Compiler<'_> {
),
name.text_range(),
);
return self.unknown();
}
};

if let Some(guard_path) = new_value.guard_path.as_ref() {
if let Some(type_override) = self.symbol_type(guard_path) {
new_value.type_id = type_override;
self.unknown()
}
}

new_value
}

fn compile_pair_field_access(
Expand Down Expand Up @@ -113,7 +105,7 @@ impl Compiler<'_> {
) -> Option<Value> {
let fields =
deconstruct_items(self.ty, ty.type_id, ty.field_names.len(), ty.nil_terminated)
.expect("invalid struct type");
.unwrap();

let Some(index) = ty.field_names.get_index_of(name.text()) else {
self.db
Expand Down Expand Up @@ -157,7 +149,7 @@ impl Compiler<'_> {
.as_ref()
.map(|field_names| {
deconstruct_items(self.ty, type_id, field_names.len(), ty.nil_terminated)
.expect("invalid struct type")
.unwrap()
})
.unwrap_or_default()
} else {
Expand Down
4 changes: 3 additions & 1 deletion crates/rue-compiler/src/compiler/expr/guard_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ impl Compiler<'_> {

if let Some(guard_path) = expr.guard_path {
let difference = self.ty.difference(expr.type_id, rhs);
value.guards.insert(guard_path, Guard::new(rhs, difference));
value
.guards
.insert(guard_path, Guard::new(Some(rhs), Some(difference)));
}

value
Expand Down
10 changes: 6 additions & 4 deletions crates/rue-compiler/src/compiler/expr/if_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ impl Compiler<'_> {
.map(|condition| self.compile_expr(&condition, Some(self.ty.std().bool)));

if let Some(condition) = condition.as_ref() {
self.type_guard_stack.push(condition.then_guards());
let overrides = self.build_overrides(condition.then_guards());
self.type_overrides.push(overrides);
}

let then_block = if_expr
.then_block()
.map(|then_block| self.compile_block_expr(&then_block, expected_type));

if condition.is_some() {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();
}

if let Some(condition) = condition.as_ref() {
self.type_guard_stack.push(condition.else_guards());
let overrides = self.build_overrides(condition.else_guards());
self.type_overrides.push(overrides);
}

let expected_type =
Expand All @@ -33,7 +35,7 @@ impl Compiler<'_> {
.map(|else_block| self.compile_block_expr(&else_block, expected_type));

if condition.is_some() {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();
}

if let Some(condition_type) = condition.as_ref().map(|condition| condition.type_id) {
Expand Down
4 changes: 2 additions & 2 deletions crates/rue-compiler/src/compiler/expr/initializer_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl Compiler<'_> {
.path()
.map(|path| self.compile_path_type(&path.items(), path.syntax().text_range()));

match ty.map(|ty| self.ty.get(ty)).cloned() {
match ty.map(|ty| self.ty.get_unaliased(ty)).cloned() {
Some(Type::Struct(struct_type)) => {
let fields = deconstruct_items(
self.ty,
Expand Down Expand Up @@ -85,7 +85,7 @@ impl Compiler<'_> {
self.unknown()
}
}
Some(_) => {
Some(..) => {
self.db.error(
ErrorKind::UninitializableType(self.type_name(ty.unwrap())),
initializer.path().unwrap().syntax().text_range(),
Expand Down
14 changes: 6 additions & 8 deletions crates/rue-compiler/src/compiler/expr/path_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
Compiler,
},
hir::Hir,
symbol::{Function, Symbol},
symbol::Symbol,
value::{GuardPath, Value},
ErrorKind,
};
Expand Down Expand Up @@ -75,18 +75,16 @@ impl Compiler<'_> {
return self.unknown();
}

let type_override = self.symbol_type(&GuardPath::new(symbol_id));
let type_id = self.symbol_type(symbol_id);
let reference = self.db.alloc_hir(Hir::Reference(symbol_id, text_range));

let mut value = match self.db.symbol(symbol_id).clone() {
Symbol::Unknown | Symbol::Module(..) => unreachable!(),
Symbol::Function(Function { type_id, .. })
| Symbol::InlineFunction(Function { type_id, .. })
| Symbol::Parameter(type_id) => Value::new(reference, type_override.unwrap_or(type_id)),
Symbol::Function(..) | Symbol::InlineFunction(..) | Symbol::Parameter(..) => {
Value::new(reference, type_id)
}
Symbol::Let(mut value) | Symbol::Const(mut value) | Symbol::InlineConst(mut value) => {
if let Some(type_id) = type_override {
value.type_id = type_id;
}
value.type_id = type_id;
value.hir_id = reference;
value
}
Expand Down
5 changes: 3 additions & 2 deletions crates/rue-compiler/src/compiler/stmt/if_stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ impl Compiler<'_> {
let scope_id = self.db.alloc_scope(Scope::default());

// We can apply any type guards from the condition.
self.type_guard_stack.push(condition.then_guards());
let overrides = self.build_overrides(condition.then_guards());
self.type_overrides.push(overrides);

// Compile the then block.
self.scope_stack.push(scope_id);
let summary = self.compile_block(&then_block, expected_type);
self.scope_stack.pop().unwrap();

// Pop the type guards, since we've left the scope.
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

// If there's an implicit return, we want to raise an error.
// This could technically work but makes the intent of the code unclear.
Expand Down
Loading

0 comments on commit d01a2f5

Please sign in to comment.