Skip to content

Commit

Permalink
Fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Jul 26, 2024
1 parent ad7a304 commit 4baf36e
Show file tree
Hide file tree
Showing 13 changed files with 246 additions and 126 deletions.
18 changes: 8 additions & 10 deletions crates/rue-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,18 @@ impl<'a> Compiler<'a> {
None
}

fn type_name(&self, type_id: TypeId, debug: bool) -> String {
fn type_name(&self, type_id: TypeId) -> String {
let mut names = HashMap::new();

if !debug {
for &scope_id in &self.scope_stack {
for type_id in self.db.scope(scope_id).local_types() {
if let Some(name) = self.db.scope(scope_id).type_name(type_id) {
names.insert(type_id, name.to_string());
}
for &scope_id in &self.scope_stack {
for type_id in self.db.scope(scope_id).local_types() {
if let Some(name) = self.db.scope(scope_id).type_name(type_id) {
names.insert(type_id, name.to_string());
}
}
}

self.ty.stringify_named(type_id, names, debug)
self.ty.stringify_named(type_id, names)
}

fn type_check(&mut self, from: TypeId, to: TypeId, range: TextRange) {
Expand All @@ -146,7 +144,7 @@ impl<'a> Compiler<'a> {

if comparison > Comparison::Assignable {
self.db.error(
ErrorKind::TypeMismatch(self.type_name(from, false), self.type_name(to, false)),
ErrorKind::TypeMismatch(self.type_name(from), self.type_name(to)),
range,
);
}
Expand All @@ -162,7 +160,7 @@ impl<'a> Compiler<'a> {

if comparison > Comparison::Castable {
self.db.error(
ErrorKind::CastMismatch(self.type_name(from, false), self.type_name(to, false)),
ErrorKind::CastMismatch(self.type_name(from), self.type_name(to)),
range,
);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/rue-compiler/src/compiler/expr/binary_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ impl Compiler<'_> {

if self.ty.compare(lhs.type_id, self.ty.std().bytes) > Comparison::Castable {
self.db.error(
ErrorKind::NonAtomEquality(self.type_name(lhs.type_id, false)),
ErrorKind::NonAtomEquality(self.type_name(lhs.type_id)),
text_range,
);
is_atom = false;
}

if self.ty.compare(rhs.type_id, self.ty.std().bytes) > Comparison::Castable {
self.db.error(
ErrorKind::NonAtomEquality(self.type_name(rhs.type_id, false)),
ErrorKind::NonAtomEquality(self.type_name(rhs.type_id)),
text_range,
);
is_atom = false;
Expand Down
4 changes: 2 additions & 2 deletions crates/rue-compiler/src/compiler/expr/field_access_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl Compiler<'_> {
self.db.error(
ErrorKind::InvalidFieldAccess(
field_name.to_string(),
self.type_name(old_value.type_id, false),
self.type_name(old_value.type_id),
),
field_name.text_range(),
);
Expand All @@ -132,7 +132,7 @@ impl Compiler<'_> {
self.db.error(
ErrorKind::InvalidFieldAccess(
field_name.to_string(),
self.type_name(old_value.type_id, false),
self.type_name(old_value.type_id),
),
field_name.text_range(),
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Compiler<'_> {
if let Some(callee) = callee.as_ref() {
if function_type.is_none() {
self.db.error(
ErrorKind::UncallableType(self.type_name(callee.type_id, false)),
ErrorKind::UncallableType(self.type_name(callee.type_id)),
call.callee().unwrap().syntax().text_range(),
);
}
Expand Down
13 changes: 5 additions & 8 deletions crates/rue-compiler/src/compiler/expr/guard_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ impl Compiler<'_> {

let Ok(check) = self.ty.check(expr.type_id, rhs) else {
self.db.error(
ErrorKind::RecursiveTypeCheck(
self.type_name(expr.type_id, false),
self.type_name(rhs, false),
),
ErrorKind::RecursiveTypeCheck(self.type_name(expr.type_id), self.type_name(rhs)),
guard.syntax().text_range(),
);
return self.unknown();
Expand All @@ -40,17 +37,17 @@ impl Compiler<'_> {
Check::True => {
self.db.warning(
WarningKind::UnnecessaryTypeCheck(
self.type_name(expr.type_id, false),
self.type_name(rhs, false),
self.type_name(expr.type_id),
self.type_name(rhs),
),
guard.syntax().text_range(),
);
}
Check::False => {
self.db.error(
ErrorKind::ImpossibleTypeCheck(
self.type_name(expr.type_id, false),
self.type_name(rhs, false),
self.type_name(expr.type_id),
self.type_name(rhs),
),
guard.syntax().text_range(),
);
Expand Down
6 changes: 2 additions & 4 deletions crates/rue-compiler/src/compiler/expr/initializer_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,15 @@ impl Compiler<'_> {
}
} else {
self.db.error(
ErrorKind::InvalidEnumVariantInitializer(
self.type_name(ty.unwrap(), false),
),
ErrorKind::InvalidEnumVariantInitializer(self.type_name(ty.unwrap())),
initializer.path().unwrap().syntax().text_range(),
);
self.unknown()
}
}
Some(_) => {
self.db.error(
ErrorKind::UninitializableType(self.type_name(ty.unwrap(), false)),
ErrorKind::UninitializableType(self.type_name(ty.unwrap())),
initializer.path().unwrap().syntax().text_range(),
);
self.unknown()
Expand Down
2 changes: 1 addition & 1 deletion crates/rue-compiler/src/compiler/expr/path_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Compiler<'_> {
if let Type::Variant(variant) = self.ty.get(type_id).clone() {
if variant.field_names.is_some() {
self.db.error(
ErrorKind::InvalidEnumVariantReference(self.type_name(type_id, false)),
ErrorKind::InvalidEnumVariantReference(self.type_name(type_id)),
text_range,
);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/rue-compiler/src/compiler/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl Compiler<'_> {
Path::Type(type_id) => {
let Type::Enum(enum_type) = self.ty.get(type_id) else {
self.db.error(
ErrorKind::InvalidTypePath(self.type_name(type_id, false)),
ErrorKind::InvalidTypePath(self.type_name(type_id)),
name.text_range(),
);
return None;
Expand Down
111 changes: 65 additions & 46 deletions crates/rue-typing/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,51 @@ pub(crate) fn compare_type(
max(first, rest)
}

// Unions can be assigned to anything so long as each of the items in the union are also.
(Type::Union(items), _) => {
let items = items.clone();
let mut result = Comparison::Assignable;

let mut any_castable = false;

for item in items {
let cmp = compare_type(db, item, rhs, ctx);
result = max(result, cmp);

if compare_type(db, rhs, item, ctx) <= Comparison::Castable {
any_castable = true;
}
}

if result == Comparison::Incompatible && any_castable {
Comparison::Superset
} else {
result
}
}

// Anything can be assigned to a union so long as it's assignable to at least one of the items.
(_, Type::Union(items)) => {
let items = items.clone();
let mut result = Comparison::Incompatible;
let mut any_incompatible = false;

for item in &items {
let cmp = compare_type(db, lhs, *item, ctx);
result = min(result, cmp);

if cmp == Comparison::Incompatible {
any_incompatible = true;
}
}

if any_incompatible && result == Comparison::Superset {
Comparison::Incompatible
} else {
max(result, Comparison::Assignable)
}
}

// We need to push substititons onto the stack in order to accurately compare them.
(Type::Lazy(lazy), _) => {
ctx.substitution_stack
Expand Down Expand Up @@ -267,7 +312,7 @@ pub(crate) fn compare_type(

// Variants can be assigned to enums if the structure is assignable and it's the same enum.
(Type::Variant(variant), Type::Enum(ty)) => {
let comparison = compare_type(db, variant.type_id, ty.type_id, ctx);
let comparison = compare_type(db, lhs, ty.type_id, ctx);

if variant.original_enum_type_id == ty.original_type_id {
max(comparison, Comparison::Assignable)
Expand Down Expand Up @@ -306,51 +351,6 @@ pub(crate) fn compare_type(
),
(Type::Callable(..), _) => compare_type(db, lhs, db.std().any, ctx),

// Unions can be assigned to anything so long as each of the items in the union are also.
(Type::Union(items), _) => {
let items = items.clone();
let mut result = Comparison::Assignable;

let mut any_castable = false;

for item in items {
let cmp = compare_type(db, item, rhs, ctx);
result = max(result, cmp);

if compare_type(db, rhs, item, ctx) <= Comparison::Castable {
any_castable = true;
}
}

if result == Comparison::Incompatible && any_castable {
Comparison::Superset
} else {
result
}
}

// Anything can be assigned to a union so long as it's assignable to at least one of the items.
(_, Type::Union(items)) => {
let items = items.clone();
let mut result = Comparison::Incompatible;
let mut any_incompatible = false;

for item in items {
let cmp = compare_type(db, lhs, item, ctx);
result = min(result, cmp);

if cmp == Comparison::Incompatible {
any_incompatible = true;
}
}

if any_incompatible && result == Comparison::Superset {
Comparison::Incompatible
} else {
max(result, Comparison::Assignable)
}
}

// Generics are resolved by looking up the substitution in the stack.
// If we're infering, we'll push the substitution onto the proper generic stack frame.
(_, Type::Generic) => {
Expand Down Expand Up @@ -816,4 +816,23 @@ mod tests {
assert_eq!(db.compare(list, list), Comparison::Equal);
assert_eq!(db.compare(pair, list), Comparison::Assignable);
}

#[test]
fn test_compare_pair_union() {
let mut db = TypeSystem::new();
let types = db.std();

let pair_enum = db.alloc(Type::Pair(types.int, types.nil));
let pair_enum = db.alloc(Type::Pair(types.int, pair_enum));
let zero = db.alloc(Type::Value(BigInt::ZERO));
let pair_enum = db.alloc(Type::Pair(zero, pair_enum));

let int_enum = db.alloc(Type::Pair(types.int, types.nil));
let one = db.alloc(Type::Value(BigInt::one()));
let int_enum = db.alloc(Type::Pair(one, int_enum));

let union = db.alloc(Type::Union(vec![pair_enum, int_enum]));

assert_eq!(db.compare(pair_enum, union), Comparison::Assignable);
}
}
Loading

0 comments on commit 4baf36e

Please sign in to comment.