Skip to content

Commit

Permalink
Move recursive group analysis logic to a separate module
Browse files Browse the repository at this point in the history
  • Loading branch information
purefunctor committed Feb 10, 2024
1 parent e285440 commit 131683f
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 81 deletions.
103 changes: 22 additions & 81 deletions crates/analyzer-module/src/infer/rules.rs
Original file line number Diff line number Diff line change
@@ -1,67 +1,24 @@
//! Implements inference rules.
mod recursive;

use std::sync::Arc;

use files::FileId;
use petgraph::{algo::kosaraju_scc, graphmap::DiGraphMap};

use rustc_hash::{FxHashMap, FxHashSet};

use crate::{
pub(self) use crate::{
id::InFile,
index::nominal::DataGroupId,
infer::pretty_print,
scope::{ResolveInfo, TypeConstructorKind},
surface::tree::*,
InferenceDatabase,
};

use super::{CoreType, CoreTypeId, InferenceResult};

// region: Recursive Binding Groups

#[derive(Debug)]
struct RecursiveGroupBuilder<'a> {
resolve_info: &'a ResolveInfo,
type_graph: DiGraphMap<TypeConstructorKind, ()>,
}
use self::recursive::{recursive_data_groups, recursive_value_groups};

impl<'a> RecursiveGroupBuilder<'a> {
fn new(resolve_info: &'a ResolveInfo) -> RecursiveGroupBuilder<'a> {
let type_graph = DiGraphMap::default();
RecursiveGroupBuilder { resolve_info, type_graph }
}

fn analyze_type(&mut self, data_id: DataGroupId, arena: &SurfaceArena, type_id: TypeId) {
match &arena[type_id] {
Type::Arrow(arguments, result) => {
for argument in arguments {
self.analyze_type(data_id, arena, *argument);
}
self.analyze_type(data_id, arena, *result);
}
Type::Application(function, arguments) => {
self.analyze_type(data_id, arena, *function);
for argument in arguments {
self.analyze_type(data_id, arena, *argument);
}
}
Type::Constructor(_) => {
if let Some(type_constructor) = self.resolve_info.per_type_type.get(&type_id) {
let dependent = TypeConstructorKind::Data(data_id);
let dependency = type_constructor.kind;
self.type_graph.add_edge(dependent, dependency, ());
}
}
Type::Parenthesized(parenthesized) => {
self.analyze_type(data_id, arena, *parenthesized);
}
Type::Variable(_) => (),
Type::NotImplemented => (),
}
}
}

// endregion
pub(self) use super::{CoreType, CoreTypeId, InferenceResult};

// region: Type Inference Rules

Expand Down Expand Up @@ -357,45 +314,29 @@ pub(super) fn file_infer_query(
let (surface, arena) = db.file_surface(file_id);
let resolve = db.file_resolve(file_id);

let mut builder = RecursiveGroupBuilder::new(&resolve);
surface.body.iter_data_declarations().for_each(|data_declaration| {
builder.type_graph.add_node(TypeConstructorKind::Data(data_declaration.id));
data_declaration.constructors.values().for_each(|data_constructor| {
data_constructor.fields.iter().for_each(|field| {
builder.analyze_type(data_declaration.id, &arena, *field);
});
});
});

let mut ctx = InferContext::new(file_id, &arena, &resolve);

for components in kosaraju_scc(&builder.type_graph) {
for TypeConstructorKind::Data(data_group_id) in components {
let index = surface.body.data_declarations.get(&data_group_id).unwrap_or_else(|| {
unreachable!("impossible: data_group_id comes from iter_data_declarations");
});
let Declaration::DataDeclaration(data_declaration) = &surface.body.declarations[*index]
else {
unreachable!("impossible: an invalid index was set to data_declarations");
let recursive_data =
recursive_data_groups(&arena, &resolve, surface.body.iter_data_declarations());
for recursive_group in recursive_data {
for data_group_id in recursive_group {
let Some(data_declaration) = surface.body.data_declaration(data_group_id) else {
unreachable!("impossible: unknown data_group_id");
};
infer_data_declaration(&mut ctx, db, file_id, data_declaration);
}
}

surface
.body
.declarations
.iter()
.filter_map(|declaration| {
if let Declaration::ValueDeclaration(value_declaration) = declaration {
Some(value_declaration)
} else {
None
}
})
.for_each(|value_declaration| {
infer_value_declaration(&mut ctx, db, &value_declaration);
});
let recursive_value =
recursive_value_groups(&arena, &resolve, surface.body.iter_value_declarations());
for recursive_group in recursive_value {
for value_group_id in recursive_group {
let Some(value_declaration) = surface.body.value_declaration(value_group_id) else {
unreachable!("impossible: unknown value_group_id");
};
infer_value_declaration(&mut ctx, db, value_declaration);
}
}

Arc::new(ctx.result)
}
145 changes: 145 additions & 0 deletions crates/analyzer-module/src/infer/rules/recursive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//! Implements grouping for recursive declarations in a module.
use itertools::Itertools;
use petgraph::{algo::kosaraju_scc, graphmap::DiGraphMap};

use crate::{
index::nominal::{DataGroupId, ValueGroupId},
scope::{ResolveInfo, TypeConstructorKind, VariableResolution},
surface::{tree::*, visit::*},
};

struct AnalyzeRecursiveGroupCtx<'ast, 'env> {
arena: &'ast SurfaceArena,
resolve: &'env ResolveInfo,
dependent: Option<NodeKind>,
graph: DiGraphMap<NodeKind, ()>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum NodeKind {
DataGroupId(DataGroupId),
ValueGroupId(ValueGroupId),
}

impl<'ast, 'env> AnalyzeRecursiveGroupCtx<'ast, 'env> {
fn new(
arena: &'ast SurfaceArena,
resolve: &'env ResolveInfo,
) -> AnalyzeRecursiveGroupCtx<'ast, 'env> {
let dependent = None;
let graph = DiGraphMap::default();
AnalyzeRecursiveGroupCtx { arena, resolve, dependent, graph }
}

fn with_dependent(&mut self, dependent: NodeKind) {
self.graph.add_node(dependent);
self.dependent = Some(dependent);
}
}

impl<'ast> Visitor<'ast> for AnalyzeRecursiveGroupCtx<'ast, '_> {
fn arena(&self) -> &'ast SurfaceArena {
self.arena
}

fn visit_expr(&mut self, expr_id: ExprId) {
let Some(dependent) = self.dependent else {
unreachable!("impossible: dependent is unset!");
};
match &self.arena[expr_id] {
Expr::Variable(_) => {
if let Some(VariableResolution::Local(value_id)) =
self.resolve.per_variable_expr.get(&expr_id)
{
let dependency = NodeKind::ValueGroupId(*value_id);
self.graph.add_edge(dependent, dependency, ());
}
}
_ => default_visit_expr(self, expr_id),
}
}

fn visit_type(&mut self, type_id: TypeId) {
let Some(dependent) = self.dependent else {
unreachable!("impossible: dependent is unset!");
};
match &self.arena[type_id] {
Type::Constructor(_) => {
if let Some(type_constructor) = self.resolve.per_type_type.get(&type_id) {
let dependency = match type_constructor.kind {
TypeConstructorKind::Data(data_id) => NodeKind::DataGroupId(data_id),
};
self.graph.add_edge(dependent, dependency, ());
}
}
_ => default_visit_type(self, type_id),
}
}
}

pub(super) fn recursive_data_groups<'ast, 'env>(
arena: &'ast SurfaceArena,
resolve: &'env ResolveInfo,
data_declarations: impl Iterator<Item = &'ast DataDeclaration>,
) -> Vec<Vec<DataGroupId>> {
let mut ctx = AnalyzeRecursiveGroupCtx::new(arena, resolve);
for data_declaration in data_declarations {
ctx.with_dependent(NodeKind::DataGroupId(data_declaration.id));
for (_, data_constructor) in &data_declaration.constructors {
for field in &data_constructor.fields {
ctx.visit_type(*field);
}
}
}
kosaraju_scc(&ctx.graph)
.into_iter()
.map(|components| {
components
.into_iter()
.map(|node_kind| {
if let NodeKind::DataGroupId(data_group_id) = node_kind {
data_group_id
} else {
unreachable!("impossible: invalid node_kind!")
}
})
.collect_vec()
})
.collect_vec()
}

pub(super) fn recursive_value_groups<'ast, 'env>(
arena: &'ast SurfaceArena,
resolve: &'env ResolveInfo,
value_declarations: impl Iterator<Item = &'ast ValueDeclaration>,
) -> Vec<Vec<ValueGroupId>> {
let mut ctx = AnalyzeRecursiveGroupCtx::new(arena, resolve);
for value_declaration in value_declarations {
ctx.with_dependent(NodeKind::ValueGroupId(value_declaration.id));
// TODO: Should this be a visitor method instead?
for equation in &value_declaration.equations {
match &equation.binding {
Binding::Unconditional { where_expr } => {
ctx.visit_let_bindings(&where_expr.let_bindings);
ctx.visit_expr(where_expr.expr_id);
}
}
}
}
kosaraju_scc(&ctx.graph)
.into_iter()
.map(|components| {
components
.into_iter()
.map(|node_kind| {
if let NodeKind::ValueGroupId(value_group_id) = node_kind {
value_group_id
} else {
unreachable!("impossible: invalid node_kind!")
}
})
.collect_vec()
})
.collect_vec()
}

0 comments on commit 131683f

Please sign in to comment.