Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] initial fixpoint iteration #14029

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ rand = { version = "0.8.5" }
rayon = { version = "1.10.0" }
regex = { version = "1.10.2" }
rustc-hash = { version = "2.0.0" }
salsa = { git = "https://github.com/salsa-rs/salsa.git", rev = "254c749b02cde2fd29852a7463a33e800b771758" }
salsa = { git = "https://github.com/salsa-rs/salsa.git", rev = "c1bbdcff28c2675f622d7e7fe10f5a0ca073f221" }
schemars = { version = "0.8.16" }
seahash = { version = "4.1.0" }
serde = { version = "1.0.197", features = ["derive"] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl<'db> Definition<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}

#[allow(unused)]
pub(crate) fn category(self, db: &'db dyn Db) -> DefinitionCategory {
self.kind(db).category()
}
Expand Down
72 changes: 42 additions & 30 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,9 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Ty
TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index).finish()
}

/// Cycle recovery for [`infer_definition_types()`]: for now, just [`Type::Unknown`]
/// TODO fixpoint iteration
fn infer_definition_types_cycle_recovery<'db>(
db: &'db dyn Db,
_cycle: &salsa::Cycle,
input: Definition<'db>,
) -> TypeInference<'db> {
tracing::trace!("infer_definition_types_cycle_recovery");
let mut inference = TypeInference::empty(input.scope(db));
let category = input.category(db);
if category.is_declaration() {
inference.declarations.insert(input, Type::Unknown);
}
if category.is_binding() {
inference.bindings.insert(input, Type::Unknown);
}
// TODO we don't fill in expression types for the cycle-participant definitions, which can
// later cause a panic when looking up an expression type.
inference
}

/// Infer all types for a [`Definition`] (including sub-expressions).
/// Use when resolving a symbol name use or public type of a symbol.
#[salsa::tracked(return_ref, recovery_fn=infer_definition_types_cycle_recovery)]
#[salsa::tracked(return_ref, cycle_fn=cycle_recover, cycle_initial=cycle_initial)]
pub(crate) fn infer_definition_types<'db>(
db: &'db dyn Db,
definition: Definition<'db>,
Expand All @@ -119,6 +98,20 @@ pub(crate) fn infer_definition_types<'db>(
TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index).finish()
}

fn cycle_recover<'db>(
_db: &'db dyn Db,
_value: &TypeInference<'db>,
count: u32,
_definition: Definition<'db>,
) -> salsa::CycleRecoveryAction<TypeInference<'db>> {
assert!(count < 10, "cycle did not converge within 10 iterations");
salsa::CycleRecoveryAction::Iterate
}

fn cycle_initial<'db>(db: &'db dyn Db, definition: Definition<'db>) -> TypeInference<'db> {
TypeInference::empty(definition.scope(db), Some(Type::Never))
}

/// Infer types for all deferred type expressions in a [`Definition`].
///
/// Deferred expressions are type expressions (annotations, base classes, aliases...) in a stub
Expand Down Expand Up @@ -191,25 +184,33 @@ pub(crate) struct TypeInference<'db> {
/// Are there deferred type expressions in this region?
has_deferred: bool,

/// The scope belong to this region.
/// The scope this region is part of.
scope: ScopeId<'db>,

/// The fallback type for all expressions/bindings/declarations.
fallback_ty: Option<Type<'db>>,
}

impl<'db> TypeInference<'db> {
pub(crate) fn empty(scope: ScopeId<'db>) -> Self {
pub(crate) fn empty(scope: ScopeId<'db>, fallback_ty: Option<Type<'db>>) -> Self {
Self {
expressions: FxHashMap::default(),
bindings: FxHashMap::default(),
declarations: FxHashMap::default(),
diagnostics: TypeCheckDiagnostics::default(),
has_deferred: false,
scope,
fallback_ty,
}
}

#[track_caller]
pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
self.expressions[&expression]
if let Some(fallback) = self.fallback_ty {
self.try_expression_ty(expression).unwrap_or(fallback)
} else {
self.expressions[&expression]
}
}

pub(crate) fn try_expression_ty(&self, expression: ScopedExpressionId) -> Option<Type<'db>> {
Expand All @@ -218,12 +219,23 @@ impl<'db> TypeInference<'db> {

#[track_caller]
pub(crate) fn binding_ty(&self, definition: Definition<'db>) -> Type<'db> {
self.bindings[&definition]
if let Some(fallback) = self.fallback_ty {
self.bindings.get(&definition).copied().unwrap_or(fallback)
} else {
self.bindings[&definition]
}
}

#[track_caller]
pub(crate) fn declaration_ty(&self, definition: Definition<'db>) -> Type<'db> {
self.declarations[&definition]
if let Some(fallback) = self.fallback_ty {
self.declarations
.get(&definition)
.copied()
.unwrap_or(fallback)
} else {
self.declarations[&definition]
}
}

pub(crate) fn diagnostics(&self) -> &[std::sync::Arc<TypeCheckDiagnostic>] {
Expand Down Expand Up @@ -324,7 +336,7 @@ impl<'db> TypeInferenceBuilder<'db> {
index,
region,
file,
types: TypeInference::empty(scope),
types: TypeInference::empty(scope, None),
diagnostics: TypeCheckDiagnosticsBuilder::new(db, file),
}
}
Expand Down Expand Up @@ -4530,8 +4542,8 @@ mod tests {
",
)?;

// TODO: sys.version_info, and need to understand @final and @type_check_only
assert_public_ty(&db, "src/a.py", "x", "EllipsisType | Unknown");
// TODO: sys.version_info
assert_public_ty(&db, "src/a.py", "x", "EllipsisType | ellipsis");

Ok(())
}
Expand Down
Loading