Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement trait upcasting #821

Merged
merged 7 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 181 additions & 53 deletions chalk-solve/src/clauses/builtin_traits/unsize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashSet;
use std::iter;
use std::ops::ControlFlow;

use crate::clauses::super_traits::super_traits;
use crate::clauses::ClauseBuilder;
use crate::rust_ir::AdtKind;
use crate::{Interner, RustIrDatabase, TraitRef, WellKnownTrait};
Expand Down Expand Up @@ -136,17 +137,27 @@ fn uses_outer_binder_params<I: Interner>(
matches!(flow, ControlFlow::Break(_))
}

fn principal_id<I: Interner>(
fn principal_trait_ref<I: Interner>(
db: &dyn RustIrDatabase<I>,
bounds: &Binders<QuantifiedWhereClauses<I>>,
) -> Option<TraitId<I>> {
let interner = db.interner();

) -> Option<Binders<Binders<TraitRef<I>>>> {
bounds
.skip_binders()
.iter(interner)
.filter_map(|b| b.trait_id())
.find(|&id| !db.trait_datum(id).is_auto_trait())
.map_ref(|b| b.iter(db.interner()))
.into_iter()
.find_map(|b| {
b.filter_map(|qwc| {
qwc.as_ref().filter_map(|wc| match wc {
WhereClause::Implemented(trait_ref) => {
if db.trait_datum(trait_ref.trait_id).is_auto_trait() {
None
} else {
Some(trait_ref.clone())
}
}
_ => None,
})
})
})
}

fn auto_trait_ids<'a, I: Interner>(
Expand Down Expand Up @@ -187,10 +198,10 @@ pub fn add_unsize_program_clauses<I: Interner>(
// could be lifted.
//
// for more info visit `fn assemble_candidates_for_unsizing` and
// `fn confirm_builtin_unisize_candidate` in rustc.
// `fn confirm_builtin_unsize_candidate` in rustc.

match (source_ty.kind(interner), target_ty.kind(interner)) {
// dyn Trait + AutoX + 'a -> dyn Trait + AutoY + 'b
// dyn TraitA + AutoA + 'a -> dyn TraitB + AutoB + 'b
(
TyKind::Dyn(DynTy {
bounds: bounds_a,
Expand All @@ -201,13 +212,33 @@ pub fn add_unsize_program_clauses<I: Interner>(
lifetime: lifetime_b,
}),
) => {
let principal_a = principal_id(db, bounds_a);
let principal_b = principal_id(db, bounds_b);
let principal_trait_ref_a = principal_trait_ref(db, bounds_a);
let principal_a = principal_trait_ref_a
.as_ref()
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);
let principal_b = principal_trait_ref(db, bounds_b)
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);

// Include super traits in a list of auto traits for A,
// to allow `dyn Trait -> dyn Trait + X` if `Trait: X`.
let auto_trait_ids_a: Vec<_> = auto_trait_ids(db, bounds_a)
.chain(principal_a.into_iter().flat_map(|principal_a| {
super_traits(db, principal_a)
.into_value_and_skipped_binders()
.0
.0
.into_iter()
.map(|x| x.skip_binders().trait_id)
.filter(|&x| db.trait_datum(x).is_auto_trait())
}))
.collect();

let auto_trait_ids_a: Vec<_> = auto_trait_ids(db, bounds_a).collect();
let auto_trait_ids_b: Vec<_> = auto_trait_ids(db, bounds_b).collect();

let may_apply = principal_a == principal_b
// If B has a principal, then A must as well
// (i.e. we allow dropping principal, but not creating a principal out of thin air).
// `AutoB` must be a subset of `AutoA`.
let may_apply = principal_a.is_some() >= principal_b.is_some()
&& auto_trait_ids_b
.iter()
.all(|id_b| auto_trait_ids_a.iter().any(|id_a| id_a == id_b));
Expand All @@ -216,6 +247,13 @@ pub fn add_unsize_program_clauses<I: Interner>(
return;
}

// Check that source lifetime outlives target lifetime
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
a: lifetime_a.clone(),
b: lifetime_b.clone(),
})
.cast(interner);

// COMMENT FROM RUSTC:
// ------------------
// Require that the traits involved in this upcast are **equal**;
Expand All @@ -233,48 +271,138 @@ pub fn add_unsize_program_clauses<I: Interner>(
// with what our behavior should be there. -nikomatsakis
// ------------------

// Construct a new trait object type by taking the source ty,
// filtering out auto traits of source that are not present in target
// and changing source lifetime to target lifetime.
//
// In order for the coercion to be valid, this new type
// should be equal to target type.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds.iter(interner).filter(|bound| {
let trait_id = match bound.trait_id() {
Some(id) => id,
None => return true,
};

if auto_trait_ids_a.iter().all(|&id_a| id_a != trait_id) {
return true;
}
auto_trait_ids_b.iter().any(|&id_b| id_b == trait_id)
if principal_a == principal_b || principal_b.is_none() {
// Construct a new trait object type by taking the source ty,
// replacing auto traits of source with those of target,
// and changing source lifetime to target lifetime.
//
// In order for the coercion to be valid, this new type
// should be equal to target type.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds
.iter(interner)
.cloned()
.filter_map(|bound| {
let Some(trait_id) = bound.trait_id() else {
// Keep non-"implements" bounds as-is
return Some(bound);
};

// Auto traits are already checked above, ignore them
// (we'll use the ones from B below)
if db.trait_datum(trait_id).is_auto_trait() {
return None;
}

// The only "implements" bound that is not an auto trait, is the principal
assert_eq!(Some(trait_id), principal_a);

// Only include principal_a if the principal_b is also present
// (this allows dropping principal, `dyn Tr+A -> dyn A`)
principal_b.is_some().then(|| bound)
})
// Add auto traits from B (again, they are already checked above).
.chain(bounds_b.skip_binders().iter(interner).cloned().filter(
|bound| {
bound.trait_id().is_some_and(|trait_id| {
db.trait_datum(trait_id).is_auto_trait()
})
},
)),
)
}),
lifetime: lifetime_b.clone(),
})
.intern(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
}
.cast(interner);

builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal].iter());
} else {
// Conditions above imply that both of these are always `Some`
// (b != None, b is Some iff a is Some).
let principal_a = principal_a.unwrap();
let principal_b = principal_b.unwrap();

let principal_trait_ref_a = principal_trait_ref_a.unwrap();
let applicable_super_traits = super_traits(db, principal_a)
.map(|(super_trait_refs, _)| super_trait_refs)
.into_iter()
.filter(|trait_ref| {
trait_ref.skip_binders().skip_binders().trait_id == principal_b
});

for super_trait_ref in applicable_super_traits {
// `super_trait_ref` is, at this point, quantified over generic params of
// `principal_a` and relevant higher-ranked lifetimes that come from super
// trait elaboration (see comments on `super_traits()`).
//
// So if we have `trait Trait<'a, T>: for<'b> Super<'a, 'b, T> {}`,
// `super_trait_ref` can be something like
// `for<Self, 'a, T> for<'b> Self: Super<'a, 'b, T>`.
//
// We need to convert it into a bound for `DynTy`. We do this by substituting
// bound vars of `principal_trait_ref_a` and then fusing inner binders for
// higher-ranked lifetimes.
let rebound_super_trait_ref = principal_trait_ref_a.map_ref(|q_trait_ref_a| {
q_trait_ref_a
.map_ref(|trait_ref_a| {
super_trait_ref.substitute(interner, &trait_ref_a.substitution)
})
.fuse_binders(interner)
});

// Skip `for<Self>` binder. We'll rebind it immediately below.
let new_principal_trait_ref = rebound_super_trait_ref
.into_value_and_skipped_binders()
.0
.map(|it| it.cast(interner));

// Swap trait ref for `principal_a` with the new trait ref, drop the auto
// traits not included in the upcast target.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds.iter(interner).cloned().filter_map(|bound| {
let trait_id = match bound.trait_id() {
Some(id) => id,
None => return Some(bound),
};

if principal_a == trait_id {
Some(new_principal_trait_ref.clone())
} else {
auto_trait_ids_b.contains(&trait_id).then_some(bound)
}
}),
)
}),
)
}),
lifetime: lifetime_b.clone(),
})
.intern(interner);
lifetime: lifetime_b.clone(),
})
.intern(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
}
.cast(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
// We don't push goal for `principal_b`'s object safety because it's implied by
// `principal_a`'s object safety.
builder
.push_clause(trait_ref.clone(), [eq_goal, lifetime_outlives_goal.clone()]);
}
}
.cast(interner);

// Check that source lifetime outlives target lifetime
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
a: lifetime_a.clone(),
b: lifetime_b.clone(),
})
.cast(interner);

builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal].iter());
}

// T -> dyn Trait + 'a
Expand Down
23 changes: 22 additions & 1 deletion chalk-solve/src/clauses/super_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,28 @@ pub(super) fn push_trait_super_clauses<I: Interner>(
}
}

fn super_traits<I: Interner>(
/// Returns super-`TraitRef`s and super-`Projection`s that are quantified over the parameters of
/// `trait_id` and relevant higher-ranked lifetimes. The outer `Binders` is for the former and the
/// inner `Binders` is for the latter.
///
/// For example, given the following trait definitions and `C` as `trait_id`,
///
/// ```
/// trait A<'a, T> {}
/// trait B<'b, U> where Self: for<'x> A<'x, U> {}
/// trait C<'c, V> where Self: B<'c, V> {}
/// ```
///
/// returns the following quantified `TraitRef`s.
///
/// ```notrust
/// for<Self, 'c, V> {
/// for<'x> { Self: A<'x, V> }
/// for<> { Self: B<'c, V> }
/// for<> { Self: C<'c, V> }
/// }
/// ```
pub(crate) fn super_traits<I: Interner>(
db: &dyn RustIrDatabase<I>,
trait_id: TraitId<I>,
) -> Binders<(
Expand Down
Loading
Loading