Skip to content

Commit

Permalink
feat(sol-thir): implement unification
Browse files Browse the repository at this point in the history
  • Loading branch information
aripiprazole committed May 29, 2024
1 parent 4b798f0 commit b8cad73
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 43 deletions.
20 changes: 6 additions & 14 deletions sol-cli/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ use std::{collections::HashMap, fmt::Display, path::PathBuf};
use fxhash::FxBuildHasher;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sol_diagnostic::{Diagnostics, Report};
use sol_diagnostic::Diagnostics;
use sol_driver::RootDb;
use sol_eyre::Context;
use sol_hir::{
package::{HasManifest, Package, Version},
source::HirSource,
};
use sol_hir_lowering::hir_lower;
use sol_syntax::Source;
use sol_syntax::{parse, Source};
use sol_vfs::SourceFile;

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand All @@ -29,7 +29,7 @@ pub struct Manifest<'db> {
pub root_folder: PathBuf,
pub soruce_folder: PathBuf,
pub config: Config,
pub diagnostics: im::HashSet<Report, FxBuildHasher>,
pub diagnostics: im::Vector<sol_diagnostic::Diagnostic, FxBuildHasher>,
}

impl<'db> Manifest<'db> {
Expand Down Expand Up @@ -78,6 +78,9 @@ impl<'db> Manifest<'db> {

let file = SourceFile::new(self.db, path, name, contents);
let cst = sol_syntax::parse(self.db, file);
let diagnostics = sol_syntax::parse::accumulated::<Diagnostics>(self.db, file);

self.diagnostics.extend(diagnostics);

Ok(cst)
}
Expand Down Expand Up @@ -121,17 +124,6 @@ impl<'db> Manifest<'db> {
let mut files = im::HashMap::default();
for package in self.db.all_packages() {
for file in package.all_files(self.db) {
// Gets the diagnostics from the CST
let diagnostics = file
.errors(self.db)
.iter()
.map(|error| error.diagnostic(self.db))
.map(Report::new)
.collect_vec();

// Add syntax errors' diagnostics to the manifest
self.diagnostics.extend(diagnostics);

let hir = hir_lower(self.db, package, file);
let diagnostics = hir_lower::accumulated::<Diagnostics>(self.db, package, file);

Expand Down
23 changes: 18 additions & 5 deletions sol-diagnostic/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![feature(stmt_expr_attributes)]

use std::sync::Arc;
use std::{fmt::Debug, sync::Arc};

use miette::{MietteError, MietteSpanContents, SpanContents};
use salsa::DbWithJar;
Expand All @@ -15,7 +15,7 @@ pub struct Jar(crate::Diagnostics);
///
/// This is used to store diagnostics, accross salsa revisions.
#[salsa::accumulator]
pub struct Diagnostics(Arc<sol_eyre::Report>);
pub struct Diagnostics(Diagnostic);

/// A result type that uses [`Diagnostic`] as the error type.
pub type Result<T, E = Diagnostic> = std::result::Result<T, E>;
Expand All @@ -29,7 +29,7 @@ pub trait UnwrapOrReport<T: Default> {
impl<T: Default> UnwrapOrReport<T> for Result<T> {
fn unwrap_or_report(self, db: &dyn DiagnosticDb) -> T {
self.unwrap_or_else(|diagnostic| {
Diagnostics::push(db, diagnostic.0.clone());
Diagnostics::push(db, Diagnostic(diagnostic.0.clone()));
T::default()
})
}
Expand All @@ -54,9 +54,22 @@ impl<T> IntoSolDiagnostic<T> for Result<T, sol_eyre::Report> {
}
}

#[derive(Clone, Debug)]
/// Fail with miette error.
pub fn fail<U, E: Into<miette::Report>>(report: E) -> Result<U> {
Err(Diagnostic(Arc::new(sol_eyre::Report::Miette(
report.into(),
))))
}

#[derive(Clone)]
pub struct Diagnostic(pub Arc<sol_eyre::Report>);

impl Debug for Diagnostic {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.0, f)
}
}

impl Eq for Diagnostic {}

impl From<sol_eyre::Report> for Diagnostic {
Expand All @@ -76,7 +89,7 @@ impl PartialEq for Diagnostic {

/// Report miette error to the diagnostic accumulator.
pub fn report_error<T: Into<sol_eyre::Report>>(db: &dyn DiagnosticDb, report: T) {
Diagnostics::push(db, Arc::new(report.into()));
Diagnostics::push(db, Diagnostic(Arc::new(report.into())));
}

pub trait DiagnosticDb: DbWithJar<Jar> {}
Expand Down
13 changes: 7 additions & 6 deletions sol-driver/src/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use owo_colors::{
};
use salsa_2022::DebugWithDb;
use similar::{ChangeTag, TextDiff};
use sol_diagnostic::Diagnostic;
use sol_hir::{fmt::HirFormatter, source::HirElement};
use sol_typer::TypeTable;

Expand Down Expand Up @@ -90,11 +91,14 @@ pub fn run_test_suite(
expect: &str,
f: impl FnOnce(RootDb, SourceCode, Expect) -> sol_eyre::Result<()>,
) {
let _ = env_logger::builder()
bupropion::install(bupropion::BupropionHandlerOpts::new).unwrap();

env_logger::builder()
.is_test(true)
.filter_level(log::LevelFilter::Debug)
.filter_module("salsa_2022", log::LevelFilter::Off)
.try_init();
.try_init()
.unwrap();

let db = RootDb::default();
let mut output = Vec::new();
Expand Down Expand Up @@ -123,10 +127,7 @@ pub fn run_test_suite(
}

/// Groups the errors by file.
pub fn push_fancy_errors(
output: Expect,
outputs: &[Vec<Arc<sol_eyre::Report>>],
) -> sol_eyre::Result<()> {
pub fn push_fancy_errors(output: Expect, outputs: &[Vec<Diagnostic>]) -> sol_eyre::Result<()> {
writeln!(output, "Errors:")?;
for error in outputs.iter().flatten() {
writeln!(output, "{error:?}")?;
Expand Down
4 changes: 2 additions & 2 deletions sol-syntax/src/error_handling.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use miette::{SourceOffset, SourceSpan};
use sol_diagnostic::{Diagnostics, TextSource};
use sol_diagnostic::{report_error, TextSource};

use crate::Source;

Expand Down Expand Up @@ -85,7 +85,7 @@ impl Source {
#[salsa::tracked]
pub fn validated(self, db: &dyn crate::ParseDb) -> Source {
for error in self.errors(db) {
Diagnostics::push(db, Arc::new(error.into()));
report_error(db, error);
}

self
Expand Down
4 changes: 2 additions & 2 deletions sol-thir-lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub fn thir_eval(db: &dyn ThirLoweringDb, env: Env, term: Term) -> sol_diagnosti
}
Term::App(callee, argument) => db
.thir_eval(env, *callee)?
.apply_to_spine(db.thir_eval(env, *argument)?),
.apply_to_spine(db, db.thir_eval(env, *argument)?)?,
Term::Pi(name, implicitness, domain, codomain) => Value::Pi(Pi {
name,
implicitness,
Expand Down Expand Up @@ -128,7 +128,7 @@ pub fn thir_quote(
})
}

let (location, value) = value.force(db);
let (location, value) = value.force(db)?;

location
.map(|location| {
Expand Down
2 changes: 2 additions & 0 deletions sol-thir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#![feature(fn_traits)]
#![feature(unboxed_closures)]
#![feature(trait_upcasting)]
#![feature(box_patterns)]

use debruijin::Level;
use salsa::DbWithJar;
Expand Down Expand Up @@ -43,6 +44,7 @@ pub struct Jar(
shared::Context,
shared::Context_create_new_value,
shared::Context_insert_new_binder,
shared::Context_increase_level,
debruijin::Indices,
debruijin::Level,
debruijin::Level_as_idx,
Expand Down
6 changes: 6 additions & 0 deletions sol-thir/src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ impl Context {

#[salsa::tracked]
impl Context {
#[salsa::tracked]
pub fn increase_level(self, db: &dyn ThirDb) -> Context {
let lvl = self.lvl(db).increase(db);
Context::new(db, lvl, self.locals(db), self.env(db))
}

#[salsa::tracked]
pub fn create_new_value(self, db: &dyn ThirDb, name: Definition, value: Value) -> Context {
todo!()
Expand Down
Loading

0 comments on commit b8cad73

Please sign in to comment.