Skip to content

Commit

Permalink
Merge pull request #821 from WaffleLapkin/trait_upcast
Browse files Browse the repository at this point in the history
Implement trait upcasting
  • Loading branch information
jackh726 authored Jan 8, 2025
2 parents c83151f + 5689335 commit d2bcd64
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 55 deletions.
234 changes: 181 additions & 53 deletions chalk-solve/src/clauses/builtin_traits/unsize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashSet;
use std::iter;
use std::ops::ControlFlow;

use crate::clauses::super_traits::super_traits;
use crate::clauses::ClauseBuilder;
use crate::rust_ir::AdtKind;
use crate::{Interner, RustIrDatabase, TraitRef, WellKnownTrait};
Expand Down Expand Up @@ -136,17 +137,27 @@ fn uses_outer_binder_params<I: Interner>(
matches!(flow, ControlFlow::Break(_))
}

fn principal_id<I: Interner>(
fn principal_trait_ref<I: Interner>(
db: &dyn RustIrDatabase<I>,
bounds: &Binders<QuantifiedWhereClauses<I>>,
) -> Option<TraitId<I>> {
let interner = db.interner();

) -> Option<Binders<Binders<TraitRef<I>>>> {
bounds
.skip_binders()
.iter(interner)
.filter_map(|b| b.trait_id())
.find(|&id| !db.trait_datum(id).is_auto_trait())
.map_ref(|b| b.iter(db.interner()))
.into_iter()
.find_map(|b| {
b.filter_map(|qwc| {
qwc.as_ref().filter_map(|wc| match wc {
WhereClause::Implemented(trait_ref) => {
if db.trait_datum(trait_ref.trait_id).is_auto_trait() {
None
} else {
Some(trait_ref.clone())
}
}
_ => None,
})
})
})
}

fn auto_trait_ids<'a, I: Interner>(
Expand Down Expand Up @@ -187,10 +198,10 @@ pub fn add_unsize_program_clauses<I: Interner>(
// could be lifted.
//
// for more info visit `fn assemble_candidates_for_unsizing` and
// `fn confirm_builtin_unisize_candidate` in rustc.
// `fn confirm_builtin_unsize_candidate` in rustc.

match (source_ty.kind(interner), target_ty.kind(interner)) {
// dyn Trait + AutoX + 'a -> dyn Trait + AutoY + 'b
// dyn TraitA + AutoA + 'a -> dyn TraitB + AutoB + 'b
(
TyKind::Dyn(DynTy {
bounds: bounds_a,
Expand All @@ -201,13 +212,33 @@ pub fn add_unsize_program_clauses<I: Interner>(
lifetime: lifetime_b,
}),
) => {
let principal_a = principal_id(db, bounds_a);
let principal_b = principal_id(db, bounds_b);
let principal_trait_ref_a = principal_trait_ref(db, bounds_a);
let principal_a = principal_trait_ref_a
.as_ref()
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);
let principal_b = principal_trait_ref(db, bounds_b)
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);

// Include super traits in a list of auto traits for A,
// to allow `dyn Trait -> dyn Trait + X` if `Trait: X`.
let auto_trait_ids_a: Vec<_> = auto_trait_ids(db, bounds_a)
.chain(principal_a.into_iter().flat_map(|principal_a| {
super_traits(db, principal_a)
.into_value_and_skipped_binders()
.0
.0
.into_iter()
.map(|x| x.skip_binders().trait_id)
.filter(|&x| db.trait_datum(x).is_auto_trait())
}))
.collect();

let auto_trait_ids_a: Vec<_> = auto_trait_ids(db, bounds_a).collect();
let auto_trait_ids_b: Vec<_> = auto_trait_ids(db, bounds_b).collect();

let may_apply = principal_a == principal_b
// If B has a principal, then A must as well
// (i.e. we allow dropping principal, but not creating a principal out of thin air).
// `AutoB` must be a subset of `AutoA`.
let may_apply = principal_a.is_some() >= principal_b.is_some()
&& auto_trait_ids_b
.iter()
.all(|id_b| auto_trait_ids_a.iter().any(|id_a| id_a == id_b));
Expand All @@ -216,6 +247,13 @@ pub fn add_unsize_program_clauses<I: Interner>(
return;
}

// Check that source lifetime outlives target lifetime
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
a: lifetime_a.clone(),
b: lifetime_b.clone(),
})
.cast(interner);

// COMMENT FROM RUSTC:
// ------------------
// Require that the traits involved in this upcast are **equal**;
Expand All @@ -233,48 +271,138 @@ pub fn add_unsize_program_clauses<I: Interner>(
// with what our behavior should be there. -nikomatsakis
// ------------------

// Construct a new trait object type by taking the source ty,
// filtering out auto traits of source that are not present in target
// and changing source lifetime to target lifetime.
//
// In order for the coercion to be valid, this new type
// should be equal to target type.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds.iter(interner).filter(|bound| {
let trait_id = match bound.trait_id() {
Some(id) => id,
None => return true,
};

if auto_trait_ids_a.iter().all(|&id_a| id_a != trait_id) {
return true;
}
auto_trait_ids_b.iter().any(|&id_b| id_b == trait_id)
if principal_a == principal_b || principal_b.is_none() {
// Construct a new trait object type by taking the source ty,
// replacing auto traits of source with those of target,
// and changing source lifetime to target lifetime.
//
// In order for the coercion to be valid, this new type
// should be equal to target type.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds
.iter(interner)
.cloned()
.filter_map(|bound| {
let Some(trait_id) = bound.trait_id() else {
// Keep non-"implements" bounds as-is
return Some(bound);
};

// Auto traits are already checked above, ignore them
// (we'll use the ones from B below)
if db.trait_datum(trait_id).is_auto_trait() {
return None;
}

// The only "implements" bound that is not an auto trait, is the principal
assert_eq!(Some(trait_id), principal_a);

// Only include principal_a if the principal_b is also present
// (this allows dropping principal, `dyn Tr+A -> dyn A`)
principal_b.is_some().then(|| bound)
})
// Add auto traits from B (again, they are already checked above).
.chain(bounds_b.skip_binders().iter(interner).cloned().filter(
|bound| {
bound.trait_id().is_some_and(|trait_id| {
db.trait_datum(trait_id).is_auto_trait()
})
},
)),
)
}),
lifetime: lifetime_b.clone(),
})
.intern(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
}
.cast(interner);

builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal].iter());
} else {
// Conditions above imply that both of these are always `Some`
// (b != None, b is Some iff a is Some).
let principal_a = principal_a.unwrap();
let principal_b = principal_b.unwrap();

let principal_trait_ref_a = principal_trait_ref_a.unwrap();
let applicable_super_traits = super_traits(db, principal_a)
.map(|(super_trait_refs, _)| super_trait_refs)
.into_iter()
.filter(|trait_ref| {
trait_ref.skip_binders().skip_binders().trait_id == principal_b
});

for super_trait_ref in applicable_super_traits {
// `super_trait_ref` is, at this point, quantified over generic params of
// `principal_a` and relevant higher-ranked lifetimes that come from super
// trait elaboration (see comments on `super_traits()`).
//
// So if we have `trait Trait<'a, T>: for<'b> Super<'a, 'b, T> {}`,
// `super_trait_ref` can be something like
// `for<Self, 'a, T> for<'b> Self: Super<'a, 'b, T>`.
//
// We need to convert it into a bound for `DynTy`. We do this by substituting
// bound vars of `principal_trait_ref_a` and then fusing inner binders for
// higher-ranked lifetimes.
let rebound_super_trait_ref = principal_trait_ref_a.map_ref(|q_trait_ref_a| {
q_trait_ref_a
.map_ref(|trait_ref_a| {
super_trait_ref.substitute(interner, &trait_ref_a.substitution)
})
.fuse_binders(interner)
});

// Skip `for<Self>` binder. We'll rebind it immediately below.
let new_principal_trait_ref = rebound_super_trait_ref
.into_value_and_skipped_binders()
.0
.map(|it| it.cast(interner));

// Swap trait ref for `principal_a` with the new trait ref, drop the auto
// traits not included in the upcast target.
let new_source_ty = TyKind::Dyn(DynTy {
bounds: bounds_a.map_ref(|bounds| {
QuantifiedWhereClauses::from_iter(
interner,
bounds.iter(interner).cloned().filter_map(|bound| {
let trait_id = match bound.trait_id() {
Some(id) => id,
None => return Some(bound),
};

if principal_a == trait_id {
Some(new_principal_trait_ref.clone())
} else {
auto_trait_ids_b.contains(&trait_id).then_some(bound)
}
}),
)
}),
)
}),
lifetime: lifetime_b.clone(),
})
.intern(interner);
lifetime: lifetime_b.clone(),
})
.intern(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
}
.cast(interner);

// Check that new source is equal to target
let eq_goal = EqGoal {
a: new_source_ty.cast(interner),
b: target_ty.clone().cast(interner),
// We don't push goal for `principal_b`'s object safety because it's implied by
// `principal_a`'s object safety.
builder
.push_clause(trait_ref.clone(), [eq_goal, lifetime_outlives_goal.clone()]);
}
}
.cast(interner);

// Check that source lifetime outlives target lifetime
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
a: lifetime_a.clone(),
b: lifetime_b.clone(),
})
.cast(interner);

builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal].iter());
}

// T -> dyn Trait + 'a
Expand Down
23 changes: 22 additions & 1 deletion chalk-solve/src/clauses/super_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,28 @@ pub(super) fn push_trait_super_clauses<I: Interner>(
}
}

fn super_traits<I: Interner>(
/// Returns super-`TraitRef`s and super-`Projection`s that are quantified over the parameters of
/// `trait_id` and relevant higher-ranked lifetimes. The outer `Binders` is for the former and the
/// inner `Binders` is for the latter.
///
/// For example, given the following trait definitions and `C` as `trait_id`,
///
/// ```
/// trait A<'a, T> {}
/// trait B<'b, U> where Self: for<'x> A<'x, U> {}
/// trait C<'c, V> where Self: B<'c, V> {}
/// ```
///
/// returns the following quantified `TraitRef`s.
///
/// ```notrust
/// for<Self, 'c, V> {
/// for<'x> { Self: A<'x, V> }
/// for<> { Self: B<'c, V> }
/// for<> { Self: C<'c, V> }
/// }
/// ```
pub(crate) fn super_traits<I: Interner>(
db: &dyn RustIrDatabase<I>,
trait_id: TraitId<I>,
) -> Binders<(
Expand Down
Loading

0 comments on commit d2bcd64

Please sign in to comment.