Skip to content

Commit

Permalink
Make constraints and errors part of the result
Browse files Browse the repository at this point in the history
  • Loading branch information
purefunctor committed Feb 15, 2024
1 parent 99379c0 commit 48deef2
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 47 deletions.
9 changes: 8 additions & 1 deletion crates/analyzer-module/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,18 @@ pub struct InferMap {
pub of_value_group: FxHashMap<ValueGroupId, CoreTypeId>,
}

#[derive(Debug, PartialEq, Eq)]
pub struct InferResult {
pub constraints: Vec<Constraint>,
pub errors: Vec<InferError>,
pub map: InferMap,
}

#[salsa::query_group(InferenceStorage)]
pub trait InferenceDatabase: ScopeDatabase {
#[salsa::interned]
fn intern_type(&self, t: CoreType) -> CoreTypeId;

#[salsa::invoke(rules::file_infer_query)]
fn file_infer(&self, file_id: FileId) -> Arc<InferMap>;
fn file_infer(&self, file_id: FileId) -> Arc<InferResult>;
}
16 changes: 10 additions & 6 deletions crates/analyzer-module/src/infer/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
id::InFile, infer::pretty_print, scope::ResolveInfo, surface::tree::*, InferenceDatabase,
};

use super::{Constraint, CoreType, CoreTypeId, Hint, InferError, InferMap};
use super::{Constraint, CoreType, CoreTypeId, Hint, InferError, InferMap, InferResult};

use recursive::{recursive_data_groups, recursive_let_names, recursive_value_groups};

Expand All @@ -26,14 +26,14 @@ struct InferState {
hints: Vec<Hint>,
constraints: Vec<Constraint>,
errors: Vec<InferError>,
infer_map: InferMap,
map: InferMap,
}

struct InferContext<'a> {
file_id: FileId,
arena: &'a SurfaceArena,
resolve: &'a ResolveInfo,
imported: &'a FxHashMap<FileId, Arc<InferMap>>,
imported: &'a FxHashMap<FileId, Arc<InferResult>>,
state: InferState,
}

Expand All @@ -42,7 +42,7 @@ impl<'a> InferContext<'a> {
file_id: FileId,
arena: &'a SurfaceArena,
resolve: &'a ResolveInfo,
imported: &'a FxHashMap<FileId, Arc<InferMap>>,
imported: &'a FxHashMap<FileId, Arc<InferResult>>,
) -> InferContext<'a> {
let state = InferState::default();
InferContext { file_id, arena, resolve, state, imported }
Expand Down Expand Up @@ -89,7 +89,7 @@ impl<'i, 'a> SolveContext<'i, 'a> {
}
}

pub(super) fn file_infer_query(db: &dyn InferenceDatabase, file_id: FileId) -> Arc<InferMap> {
pub(super) fn file_infer_query(db: &dyn InferenceDatabase, file_id: FileId) -> Arc<InferResult> {
let (surface, arena) = db.file_surface(file_id);
let resolve = db.file_resolve(file_id);

Expand Down Expand Up @@ -132,5 +132,9 @@ pub(super) fn file_infer_query(db: &dyn InferenceDatabase, file_id: FileId) -> A
eprintln!("{} ~ {}", u.value, pretty_print(db, t));
}

Arc::new(infer_ctx.state.infer_map)
Arc::new(InferResult {
constraints: infer_ctx.state.constraints,
errors: infer_ctx.state.errors,
map: infer_ctx.state.map,
})
}
2 changes: 1 addition & 1 deletion crates/analyzer-module/src/infer/rules/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl InferContext<'_> {
},
);

self.state.infer_map.of_constructor.insert(*constructor_id, qualified_ty);
self.state.map.of_constructor.insert(*constructor_id, qualified_ty);
});
}
}
67 changes: 28 additions & 39 deletions crates/analyzer-module/src/infer/rules/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl InferContext<'_> {
} else {
self.fresh_unification(db)
};
self.state.infer_map.of_value_group.insert(*value_group_id, value_ty);
self.state.map.of_value_group.insert(*value_group_id, value_ty);
}

for (value_group_id, value_declaration) in value_declarations {
Expand All @@ -34,8 +34,7 @@ impl InferContext<'_> {
value_declaration: &ValueDeclaration,
) {
self.add_hint(Hint::ValueGroup(value_group_id));
let Some(value_ty) = self.state.infer_map.of_value_group.get(&value_group_id).copied()
else {
let Some(value_ty) = self.state.map.of_value_group.get(&value_group_id).copied() else {
unreachable!("impossible: caller must insert a type!");
};

Expand Down Expand Up @@ -88,11 +87,8 @@ impl InferContext<'_> {
if let Some(constructor_resolution) =
self.resolve.per_constructor_binder.get(&binder_id)
{
if let Some(constructor_ty) = self
.state
.infer_map
.of_constructor
.get(&constructor_resolution.constructor_id)
if let Some(constructor_ty) =
self.state.map.of_constructor.get(&constructor_resolution.constructor_id)
{
let constructor_ty = self.instantiate_type(db, *constructor_ty);
let (arguments_ty, _) = self.peel_arguments(db, constructor_ty);
Expand Down Expand Up @@ -146,7 +142,7 @@ impl InferContext<'_> {
Binder::Wildcard => self.fresh_unification(db),
Binder::NotImplemented => db.intern_type(CoreType::NotImplemented),
};
self.state.infer_map.of_binder.insert(binder_id, binder_ty);
self.state.map.of_binder.insert(binder_id, binder_ty);
binder_ty
}

Expand Down Expand Up @@ -191,7 +187,7 @@ impl InferContext<'_> {
for let_name_component in let_name_components {
for let_name_id in &let_name_component {
let fresh_ty = self.fresh_unification(db);
self.state.infer_map.of_let_name.insert(*let_name_id, fresh_ty);
self.state.map.of_let_name.insert(*let_name_id, fresh_ty);
}
for let_name_id in let_name_component {
self.infer_let_name(db, let_name_id);
Expand All @@ -211,7 +207,7 @@ impl InferContext<'_> {
}

fn infer_let_name(&mut self, db: &dyn InferenceDatabase, let_name_id: LetNameId) {
let Some(fresh_ty) = self.state.infer_map.of_let_name.get(&let_name_id).copied() else {
let Some(fresh_ty) = self.state.map.of_let_name.get(&let_name_id).copied() else {
unreachable!("impossible:");
};
let let_name = &self.arena[let_name_id];
Expand Down Expand Up @@ -262,7 +258,7 @@ impl InferContext<'_> {
Expr::Constructor(_) => {
if let Some(constructor) = self.resolve.per_constructor_expr.get(&expr_id) {
if let Some(constructor_ty) =
self.state.infer_map.of_constructor.get(&constructor.constructor_id)
self.state.map.of_constructor.get(&constructor.constructor_id)
{
*constructor_ty
} else {
Expand Down Expand Up @@ -312,20 +308,20 @@ impl InferContext<'_> {
if let Some(variable) = self.resolve.per_variable_expr.get(&expr_id) {
let resolved_ty = match variable {
VariableResolution::Binder(binder_id) => {
self.state.infer_map.of_binder.get(binder_id).copied()
self.state.map.of_binder.get(binder_id).copied()
}
VariableResolution::Imported(InFile { file_id, value }) => {
if let Some(result) = self.imported.get(file_id) {
result.of_value_group.get(value).copied()
result.map.of_value_group.get(value).copied()
} else {
None
}
}
VariableResolution::LetName(let_id) => {
self.state.infer_map.of_let_name.get(let_id).copied()
self.state.map.of_let_name.get(let_id).copied()
}
VariableResolution::Local(value_id) => {
self.state.infer_map.of_value_group.get(value_id).copied()
self.state.map.of_value_group.get(value_id).copied()
}
};
resolved_ty.unwrap_or_else(|| db.intern_type(CoreType::NotImplemented))
Expand All @@ -335,7 +331,7 @@ impl InferContext<'_> {
}
Expr::NotImplemented => db.intern_type(CoreType::NotImplemented),
};
self.state.infer_map.of_expr.insert(expr_id, expr_ty);
self.state.map.of_expr.insert(expr_id, expr_ty);
expr_ty
}
}
Expand All @@ -348,14 +344,11 @@ impl InferContext<'_> {
expected_ty: CoreTypeId,
) {
let assign_expected = |this: &mut Self| {
this.state.infer_map.of_binder.insert(binder_id, expected_ty);
this.state.map.of_binder.insert(binder_id, expected_ty);
};

let assign_error = |this: &mut Self| {
this.state
.infer_map
.of_binder
.insert(binder_id, db.intern_type(CoreType::NotImplemented));
this.state.map.of_binder.insert(binder_id, db.intern_type(CoreType::NotImplemented));
};

let check_literal = |this: &mut Self, name: &str| {
Expand Down Expand Up @@ -383,7 +376,7 @@ impl InferContext<'_> {
}

if let Some(constructor_ty) =
self.state.infer_map.of_constructor.get(&constructor.constructor_id)
self.state.map.of_constructor.get(&constructor.constructor_id)
{
let constructor_ty = self.instantiate_type(db, *constructor_ty);
let (arguments_ty, result_ty) = self.peel_arguments(db, constructor_ty);
Expand All @@ -393,7 +386,7 @@ impl InferContext<'_> {
}

self.unify_types(db, result_ty, expected_ty);
self.state.infer_map.of_binder.insert(binder_id, result_ty);
self.state.map.of_binder.insert(binder_id, result_ty);
} else {
assign_error(self);
}
Expand Down Expand Up @@ -450,14 +443,14 @@ impl InferContext<'_> {
expected_ty: CoreTypeId,
) {
let assign_error = |this: &mut Self| {
this.state.infer_map.of_expr.insert(expr_id, db.intern_type(CoreType::NotImplemented));
this.state.map.of_expr.insert(expr_id, db.intern_type(CoreType::NotImplemented));
};

let check_literal = |this: &mut Self, name: &str| {
let name = Name::from_raw(db.interner().intern(name));
if let CoreType::Primitive(primitive) = db.lookup_intern_type(expected_ty) {
if primitive == name {
this.state.infer_map.of_expr.insert(expr_id, expected_ty);
this.state.map.of_expr.insert(expr_id, expected_ty);
} else {
assign_error(this);
}
Expand All @@ -468,15 +461,11 @@ impl InferContext<'_> {
Expr::Application(_, _) => todo!("check_expr(Application)"),
Expr::Constructor(_) => {
if let Some(constructor) = self.resolve.per_constructor_expr.get(&expr_id) {
if let Some(constructor_ty) = self
.state
.infer_map
.of_constructor
.get(&constructor.constructor_id)
.copied()
if let Some(constructor_ty) =
self.state.map.of_constructor.get(&constructor.constructor_id).copied()
{
self.subsume_types(db, constructor_ty, expected_ty);
self.state.infer_map.of_expr.insert(expr_id, expected_ty);
self.state.map.of_expr.insert(expr_id, expected_ty);
} else {
assign_error(self);
}
Expand All @@ -502,33 +491,33 @@ impl InferContext<'_> {
if let Some(variable) = self.resolve.per_variable_expr.get(&expr_id) {
let variable_ty = match variable {
VariableResolution::Binder(binder_id) => {
self.state.infer_map.of_binder.get(binder_id)
self.state.map.of_binder.get(binder_id)
}
VariableResolution::Imported(InFile { file_id, value }) => {
if let Some(result) = self.imported.get(file_id) {
result.of_value_group.get(value)
result.map.of_value_group.get(value)
} else {
None
}
}
VariableResolution::LetName(let_id) => {
self.state.infer_map.of_let_name.get(let_id)
self.state.map.of_let_name.get(let_id)
}
VariableResolution::Local(local_id) => {
self.state.infer_map.of_value_group.get(local_id)
self.state.map.of_value_group.get(local_id)
}
};
let variable_ty = variable_ty
.copied()
.unwrap_or_else(|| db.intern_type(CoreType::NotImplemented));
self.subsume_types(db, variable_ty, expected_ty);
self.state.infer_map.of_expr.insert(expr_id, expected_ty);
self.state.map.of_expr.insert(expr_id, expected_ty);
} else {
assign_error(self);
}
}
Expr::NotImplemented => {
self.state.infer_map.of_expr.insert(expr_id, expected_ty);
self.state.map.of_expr.insert(expr_id, expected_ty);
}
}
}
Expand Down

0 comments on commit 48deef2

Please sign in to comment.