diff --git a/crates/analyzer-module/src/infer.rs b/crates/analyzer-module/src/infer.rs index a295eb9d..cc5fff99 100644 --- a/crates/analyzer-module/src/infer.rs +++ b/crates/analyzer-module/src/infer.rs @@ -34,11 +34,18 @@ pub struct InferMap { pub of_value_group: FxHashMap, } +#[derive(Debug, PartialEq, Eq)] +pub struct InferResult { + pub constraints: Vec, + pub errors: Vec, + pub map: InferMap, +} + #[salsa::query_group(InferenceStorage)] pub trait InferenceDatabase: ScopeDatabase { #[salsa::interned] fn intern_type(&self, t: CoreType) -> CoreTypeId; #[salsa::invoke(rules::file_infer_query)] - fn file_infer(&self, file_id: FileId) -> Arc; + fn file_infer(&self, file_id: FileId) -> Arc; } diff --git a/crates/analyzer-module/src/infer/rules.rs b/crates/analyzer-module/src/infer/rules.rs index 80e9c856..71406767 100644 --- a/crates/analyzer-module/src/infer/rules.rs +++ b/crates/analyzer-module/src/infer/rules.rs @@ -16,7 +16,7 @@ use crate::{ id::InFile, infer::pretty_print, scope::ResolveInfo, surface::tree::*, InferenceDatabase, }; -use super::{Constraint, CoreType, CoreTypeId, Hint, InferError, InferMap}; +use super::{Constraint, CoreType, CoreTypeId, Hint, InferError, InferMap, InferResult}; use recursive::{recursive_data_groups, recursive_let_names, recursive_value_groups}; @@ -26,14 +26,14 @@ struct InferState { hints: Vec, constraints: Vec, errors: Vec, - infer_map: InferMap, + map: InferMap, } struct InferContext<'a> { file_id: FileId, arena: &'a SurfaceArena, resolve: &'a ResolveInfo, - imported: &'a FxHashMap>, + imported: &'a FxHashMap>, state: InferState, } @@ -42,7 +42,7 @@ impl<'a> InferContext<'a> { file_id: FileId, arena: &'a SurfaceArena, resolve: &'a ResolveInfo, - imported: &'a FxHashMap>, + imported: &'a FxHashMap>, ) -> InferContext<'a> { let state = InferState::default(); InferContext { file_id, arena, resolve, state, imported } @@ -89,7 +89,7 @@ impl<'i, 'a> SolveContext<'i, 'a> { } } -pub(super) fn file_infer_query(db: &dyn InferenceDatabase, file_id: FileId) -> Arc { +pub(super) fn file_infer_query(db: &dyn InferenceDatabase, file_id: FileId) -> Arc { let (surface, arena) = db.file_surface(file_id); let resolve = db.file_resolve(file_id); @@ -132,5 +132,9 @@ pub(super) fn file_infer_query(db: &dyn InferenceDatabase, file_id: FileId) -> A eprintln!("{} ~ {}", u.value, pretty_print(db, t)); } - Arc::new(infer_ctx.state.infer_map) + Arc::new(InferResult { + constraints: infer_ctx.state.constraints, + errors: infer_ctx.state.errors, + map: infer_ctx.state.map, + }) } diff --git a/crates/analyzer-module/src/infer/rules/data.rs b/crates/analyzer-module/src/infer/rules/data.rs index d3f21a48..686af59c 100644 --- a/crates/analyzer-module/src/infer/rules/data.rs +++ b/crates/analyzer-module/src/infer/rules/data.rs @@ -48,7 +48,7 @@ impl InferContext<'_> { }, ); - self.state.infer_map.of_constructor.insert(*constructor_id, qualified_ty); + self.state.map.of_constructor.insert(*constructor_id, qualified_ty); }); } } diff --git a/crates/analyzer-module/src/infer/rules/value.rs b/crates/analyzer-module/src/infer/rules/value.rs index 324234c3..782ae453 100644 --- a/crates/analyzer-module/src/infer/rules/value.rs +++ b/crates/analyzer-module/src/infer/rules/value.rs @@ -19,7 +19,7 @@ impl InferContext<'_> { } else { self.fresh_unification(db) }; - self.state.infer_map.of_value_group.insert(*value_group_id, value_ty); + self.state.map.of_value_group.insert(*value_group_id, value_ty); } for (value_group_id, value_declaration) in value_declarations { @@ -34,8 +34,7 @@ impl InferContext<'_> { value_declaration: &ValueDeclaration, ) { self.add_hint(Hint::ValueGroup(value_group_id)); - let Some(value_ty) = self.state.infer_map.of_value_group.get(&value_group_id).copied() - else { + let Some(value_ty) = self.state.map.of_value_group.get(&value_group_id).copied() else { unreachable!("impossible: caller must insert a type!"); }; @@ -88,11 +87,8 @@ impl InferContext<'_> { if let Some(constructor_resolution) = self.resolve.per_constructor_binder.get(&binder_id) { - if let Some(constructor_ty) = self - .state - .infer_map - .of_constructor - .get(&constructor_resolution.constructor_id) + if let Some(constructor_ty) = + self.state.map.of_constructor.get(&constructor_resolution.constructor_id) { let constructor_ty = self.instantiate_type(db, *constructor_ty); let (arguments_ty, _) = self.peel_arguments(db, constructor_ty); @@ -146,7 +142,7 @@ impl InferContext<'_> { Binder::Wildcard => self.fresh_unification(db), Binder::NotImplemented => db.intern_type(CoreType::NotImplemented), }; - self.state.infer_map.of_binder.insert(binder_id, binder_ty); + self.state.map.of_binder.insert(binder_id, binder_ty); binder_ty } @@ -191,7 +187,7 @@ impl InferContext<'_> { for let_name_component in let_name_components { for let_name_id in &let_name_component { let fresh_ty = self.fresh_unification(db); - self.state.infer_map.of_let_name.insert(*let_name_id, fresh_ty); + self.state.map.of_let_name.insert(*let_name_id, fresh_ty); } for let_name_id in let_name_component { self.infer_let_name(db, let_name_id); @@ -211,7 +207,7 @@ impl InferContext<'_> { } fn infer_let_name(&mut self, db: &dyn InferenceDatabase, let_name_id: LetNameId) { - let Some(fresh_ty) = self.state.infer_map.of_let_name.get(&let_name_id).copied() else { + let Some(fresh_ty) = self.state.map.of_let_name.get(&let_name_id).copied() else { unreachable!("impossible:"); }; let let_name = &self.arena[let_name_id]; @@ -262,7 +258,7 @@ impl InferContext<'_> { Expr::Constructor(_) => { if let Some(constructor) = self.resolve.per_constructor_expr.get(&expr_id) { if let Some(constructor_ty) = - self.state.infer_map.of_constructor.get(&constructor.constructor_id) + self.state.map.of_constructor.get(&constructor.constructor_id) { *constructor_ty } else { @@ -312,20 +308,20 @@ impl InferContext<'_> { if let Some(variable) = self.resolve.per_variable_expr.get(&expr_id) { let resolved_ty = match variable { VariableResolution::Binder(binder_id) => { - self.state.infer_map.of_binder.get(binder_id).copied() + self.state.map.of_binder.get(binder_id).copied() } VariableResolution::Imported(InFile { file_id, value }) => { if let Some(result) = self.imported.get(file_id) { - result.of_value_group.get(value).copied() + result.map.of_value_group.get(value).copied() } else { None } } VariableResolution::LetName(let_id) => { - self.state.infer_map.of_let_name.get(let_id).copied() + self.state.map.of_let_name.get(let_id).copied() } VariableResolution::Local(value_id) => { - self.state.infer_map.of_value_group.get(value_id).copied() + self.state.map.of_value_group.get(value_id).copied() } }; resolved_ty.unwrap_or_else(|| db.intern_type(CoreType::NotImplemented)) @@ -335,7 +331,7 @@ impl InferContext<'_> { } Expr::NotImplemented => db.intern_type(CoreType::NotImplemented), }; - self.state.infer_map.of_expr.insert(expr_id, expr_ty); + self.state.map.of_expr.insert(expr_id, expr_ty); expr_ty } } @@ -348,14 +344,11 @@ impl InferContext<'_> { expected_ty: CoreTypeId, ) { let assign_expected = |this: &mut Self| { - this.state.infer_map.of_binder.insert(binder_id, expected_ty); + this.state.map.of_binder.insert(binder_id, expected_ty); }; let assign_error = |this: &mut Self| { - this.state - .infer_map - .of_binder - .insert(binder_id, db.intern_type(CoreType::NotImplemented)); + this.state.map.of_binder.insert(binder_id, db.intern_type(CoreType::NotImplemented)); }; let check_literal = |this: &mut Self, name: &str| { @@ -383,7 +376,7 @@ impl InferContext<'_> { } if let Some(constructor_ty) = - self.state.infer_map.of_constructor.get(&constructor.constructor_id) + self.state.map.of_constructor.get(&constructor.constructor_id) { let constructor_ty = self.instantiate_type(db, *constructor_ty); let (arguments_ty, result_ty) = self.peel_arguments(db, constructor_ty); @@ -393,7 +386,7 @@ impl InferContext<'_> { } self.unify_types(db, result_ty, expected_ty); - self.state.infer_map.of_binder.insert(binder_id, result_ty); + self.state.map.of_binder.insert(binder_id, result_ty); } else { assign_error(self); } @@ -450,14 +443,14 @@ impl InferContext<'_> { expected_ty: CoreTypeId, ) { let assign_error = |this: &mut Self| { - this.state.infer_map.of_expr.insert(expr_id, db.intern_type(CoreType::NotImplemented)); + this.state.map.of_expr.insert(expr_id, db.intern_type(CoreType::NotImplemented)); }; let check_literal = |this: &mut Self, name: &str| { let name = Name::from_raw(db.interner().intern(name)); if let CoreType::Primitive(primitive) = db.lookup_intern_type(expected_ty) { if primitive == name { - this.state.infer_map.of_expr.insert(expr_id, expected_ty); + this.state.map.of_expr.insert(expr_id, expected_ty); } else { assign_error(this); } @@ -468,15 +461,11 @@ impl InferContext<'_> { Expr::Application(_, _) => todo!("check_expr(Application)"), Expr::Constructor(_) => { if let Some(constructor) = self.resolve.per_constructor_expr.get(&expr_id) { - if let Some(constructor_ty) = self - .state - .infer_map - .of_constructor - .get(&constructor.constructor_id) - .copied() + if let Some(constructor_ty) = + self.state.map.of_constructor.get(&constructor.constructor_id).copied() { self.subsume_types(db, constructor_ty, expected_ty); - self.state.infer_map.of_expr.insert(expr_id, expected_ty); + self.state.map.of_expr.insert(expr_id, expected_ty); } else { assign_error(self); } @@ -502,33 +491,33 @@ impl InferContext<'_> { if let Some(variable) = self.resolve.per_variable_expr.get(&expr_id) { let variable_ty = match variable { VariableResolution::Binder(binder_id) => { - self.state.infer_map.of_binder.get(binder_id) + self.state.map.of_binder.get(binder_id) } VariableResolution::Imported(InFile { file_id, value }) => { if let Some(result) = self.imported.get(file_id) { - result.of_value_group.get(value) + result.map.of_value_group.get(value) } else { None } } VariableResolution::LetName(let_id) => { - self.state.infer_map.of_let_name.get(let_id) + self.state.map.of_let_name.get(let_id) } VariableResolution::Local(local_id) => { - self.state.infer_map.of_value_group.get(local_id) + self.state.map.of_value_group.get(local_id) } }; let variable_ty = variable_ty .copied() .unwrap_or_else(|| db.intern_type(CoreType::NotImplemented)); self.subsume_types(db, variable_ty, expected_ty); - self.state.infer_map.of_expr.insert(expr_id, expected_ty); + self.state.map.of_expr.insert(expr_id, expected_ty); } else { assign_error(self); } } Expr::NotImplemented => { - self.state.infer_map.of_expr.insert(expr_id, expected_ty); + self.state.map.of_expr.insert(expr_id, expected_ty); } } }