From 131683fe0596316d98d32bcd5edd726c096f9c2f Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Sun, 11 Feb 2024 03:12:33 +0800 Subject: [PATCH] Move recursive group analysis logic to a separate module --- crates/analyzer-module/src/infer/rules.rs | 103 +++---------- .../src/infer/rules/recursive.rs | 145 ++++++++++++++++++ 2 files changed, 167 insertions(+), 81 deletions(-) create mode 100644 crates/analyzer-module/src/infer/rules/recursive.rs diff --git a/crates/analyzer-module/src/infer/rules.rs b/crates/analyzer-module/src/infer/rules.rs index 59161a6e..b12c9e93 100644 --- a/crates/analyzer-module/src/infer/rules.rs +++ b/crates/analyzer-module/src/infer/rules.rs @@ -1,67 +1,24 @@ //! Implements inference rules. +mod recursive; + use std::sync::Arc; use files::FileId; -use petgraph::{algo::kosaraju_scc, graphmap::DiGraphMap}; + use rustc_hash::{FxHashMap, FxHashSet}; -use crate::{ +pub(self) use crate::{ id::InFile, - index::nominal::DataGroupId, infer::pretty_print, scope::{ResolveInfo, TypeConstructorKind}, surface::tree::*, InferenceDatabase, }; -use super::{CoreType, CoreTypeId, InferenceResult}; - -// region: Recursive Binding Groups - -#[derive(Debug)] -struct RecursiveGroupBuilder<'a> { - resolve_info: &'a ResolveInfo, - type_graph: DiGraphMap, -} +use self::recursive::{recursive_data_groups, recursive_value_groups}; -impl<'a> RecursiveGroupBuilder<'a> { - fn new(resolve_info: &'a ResolveInfo) -> RecursiveGroupBuilder<'a> { - let type_graph = DiGraphMap::default(); - RecursiveGroupBuilder { resolve_info, type_graph } - } - - fn analyze_type(&mut self, data_id: DataGroupId, arena: &SurfaceArena, type_id: TypeId) { - match &arena[type_id] { - Type::Arrow(arguments, result) => { - for argument in arguments { - self.analyze_type(data_id, arena, *argument); - } - self.analyze_type(data_id, arena, *result); - } - Type::Application(function, arguments) => { - self.analyze_type(data_id, arena, *function); - for argument in arguments { - self.analyze_type(data_id, arena, *argument); - } - } - Type::Constructor(_) => { - if let Some(type_constructor) = self.resolve_info.per_type_type.get(&type_id) { - let dependent = TypeConstructorKind::Data(data_id); - let dependency = type_constructor.kind; - self.type_graph.add_edge(dependent, dependency, ()); - } - } - Type::Parenthesized(parenthesized) => { - self.analyze_type(data_id, arena, *parenthesized); - } - Type::Variable(_) => (), - Type::NotImplemented => (), - } - } -} - -// endregion +pub(self) use super::{CoreType, CoreTypeId, InferenceResult}; // region: Type Inference Rules @@ -357,45 +314,29 @@ pub(super) fn file_infer_query( let (surface, arena) = db.file_surface(file_id); let resolve = db.file_resolve(file_id); - let mut builder = RecursiveGroupBuilder::new(&resolve); - surface.body.iter_data_declarations().for_each(|data_declaration| { - builder.type_graph.add_node(TypeConstructorKind::Data(data_declaration.id)); - data_declaration.constructors.values().for_each(|data_constructor| { - data_constructor.fields.iter().for_each(|field| { - builder.analyze_type(data_declaration.id, &arena, *field); - }); - }); - }); - let mut ctx = InferContext::new(file_id, &arena, &resolve); - for components in kosaraju_scc(&builder.type_graph) { - for TypeConstructorKind::Data(data_group_id) in components { - let index = surface.body.data_declarations.get(&data_group_id).unwrap_or_else(|| { - unreachable!("impossible: data_group_id comes from iter_data_declarations"); - }); - let Declaration::DataDeclaration(data_declaration) = &surface.body.declarations[*index] - else { - unreachable!("impossible: an invalid index was set to data_declarations"); + let recursive_data = + recursive_data_groups(&arena, &resolve, surface.body.iter_data_declarations()); + for recursive_group in recursive_data { + for data_group_id in recursive_group { + let Some(data_declaration) = surface.body.data_declaration(data_group_id) else { + unreachable!("impossible: unknown data_group_id"); }; infer_data_declaration(&mut ctx, db, file_id, data_declaration); } } - surface - .body - .declarations - .iter() - .filter_map(|declaration| { - if let Declaration::ValueDeclaration(value_declaration) = declaration { - Some(value_declaration) - } else { - None - } - }) - .for_each(|value_declaration| { - infer_value_declaration(&mut ctx, db, &value_declaration); - }); + let recursive_value = + recursive_value_groups(&arena, &resolve, surface.body.iter_value_declarations()); + for recursive_group in recursive_value { + for value_group_id in recursive_group { + let Some(value_declaration) = surface.body.value_declaration(value_group_id) else { + unreachable!("impossible: unknown value_group_id"); + }; + infer_value_declaration(&mut ctx, db, value_declaration); + } + } Arc::new(ctx.result) } diff --git a/crates/analyzer-module/src/infer/rules/recursive.rs b/crates/analyzer-module/src/infer/rules/recursive.rs new file mode 100644 index 00000000..0b6484d6 --- /dev/null +++ b/crates/analyzer-module/src/infer/rules/recursive.rs @@ -0,0 +1,145 @@ +//! Implements grouping for recursive declarations in a module. + +use itertools::Itertools; +use petgraph::{algo::kosaraju_scc, graphmap::DiGraphMap}; + +use crate::{ + index::nominal::{DataGroupId, ValueGroupId}, + scope::{ResolveInfo, TypeConstructorKind, VariableResolution}, + surface::{tree::*, visit::*}, +}; + +struct AnalyzeRecursiveGroupCtx<'ast, 'env> { + arena: &'ast SurfaceArena, + resolve: &'env ResolveInfo, + dependent: Option, + graph: DiGraphMap, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +enum NodeKind { + DataGroupId(DataGroupId), + ValueGroupId(ValueGroupId), +} + +impl<'ast, 'env> AnalyzeRecursiveGroupCtx<'ast, 'env> { + fn new( + arena: &'ast SurfaceArena, + resolve: &'env ResolveInfo, + ) -> AnalyzeRecursiveGroupCtx<'ast, 'env> { + let dependent = None; + let graph = DiGraphMap::default(); + AnalyzeRecursiveGroupCtx { arena, resolve, dependent, graph } + } + + fn with_dependent(&mut self, dependent: NodeKind) { + self.graph.add_node(dependent); + self.dependent = Some(dependent); + } +} + +impl<'ast> Visitor<'ast> for AnalyzeRecursiveGroupCtx<'ast, '_> { + fn arena(&self) -> &'ast SurfaceArena { + self.arena + } + + fn visit_expr(&mut self, expr_id: ExprId) { + let Some(dependent) = self.dependent else { + unreachable!("impossible: dependent is unset!"); + }; + match &self.arena[expr_id] { + Expr::Variable(_) => { + if let Some(VariableResolution::Local(value_id)) = + self.resolve.per_variable_expr.get(&expr_id) + { + let dependency = NodeKind::ValueGroupId(*value_id); + self.graph.add_edge(dependent, dependency, ()); + } + } + _ => default_visit_expr(self, expr_id), + } + } + + fn visit_type(&mut self, type_id: TypeId) { + let Some(dependent) = self.dependent else { + unreachable!("impossible: dependent is unset!"); + }; + match &self.arena[type_id] { + Type::Constructor(_) => { + if let Some(type_constructor) = self.resolve.per_type_type.get(&type_id) { + let dependency = match type_constructor.kind { + TypeConstructorKind::Data(data_id) => NodeKind::DataGroupId(data_id), + }; + self.graph.add_edge(dependent, dependency, ()); + } + } + _ => default_visit_type(self, type_id), + } + } +} + +pub(super) fn recursive_data_groups<'ast, 'env>( + arena: &'ast SurfaceArena, + resolve: &'env ResolveInfo, + data_declarations: impl Iterator, +) -> Vec> { + let mut ctx = AnalyzeRecursiveGroupCtx::new(arena, resolve); + for data_declaration in data_declarations { + ctx.with_dependent(NodeKind::DataGroupId(data_declaration.id)); + for (_, data_constructor) in &data_declaration.constructors { + for field in &data_constructor.fields { + ctx.visit_type(*field); + } + } + } + kosaraju_scc(&ctx.graph) + .into_iter() + .map(|components| { + components + .into_iter() + .map(|node_kind| { + if let NodeKind::DataGroupId(data_group_id) = node_kind { + data_group_id + } else { + unreachable!("impossible: invalid node_kind!") + } + }) + .collect_vec() + }) + .collect_vec() +} + +pub(super) fn recursive_value_groups<'ast, 'env>( + arena: &'ast SurfaceArena, + resolve: &'env ResolveInfo, + value_declarations: impl Iterator, +) -> Vec> { + let mut ctx = AnalyzeRecursiveGroupCtx::new(arena, resolve); + for value_declaration in value_declarations { + ctx.with_dependent(NodeKind::ValueGroupId(value_declaration.id)); + // TODO: Should this be a visitor method instead? + for equation in &value_declaration.equations { + match &equation.binding { + Binding::Unconditional { where_expr } => { + ctx.visit_let_bindings(&where_expr.let_bindings); + ctx.visit_expr(where_expr.expr_id); + } + } + } + } + kosaraju_scc(&ctx.graph) + .into_iter() + .map(|components| { + components + .into_iter() + .map(|node_kind| { + if let NodeKind::ValueGroupId(value_group_id) = node_kind { + value_group_id + } else { + unreachable!("impossible: invalid node_kind!") + } + }) + .collect_vec() + }) + .collect_vec() +}