Skip to content

Commit

Permalink
Prune unreachable variants of coroutines
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Jan 14, 2025
1 parent 3736b85 commit 802fb2a
Show file tree
Hide file tree
Showing 4 changed files with 534 additions and 0 deletions.
167 changes: 167 additions & 0 deletions compiler/rustc_mir_transform/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
//! naively generate still contains the `_a = ()` write in the unreachable block "after" the
//! return.
use rustc_abi::{FieldIdx, VariantIdx};
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
use rustc_index::bit_set::DenseBitSet;
use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
Expand Down Expand Up @@ -68,6 +72,7 @@ impl SimplifyCfg {

pub(super) fn simplify_cfg(body: &mut Body<'_>) {
CfgSimplifier::new(body).simplify();
remove_dead_coroutine_switch_variants(body);
remove_dead_blocks(body);

// FIXME: Should probably be moved into some kind of pass manager
Expand Down Expand Up @@ -292,6 +297,168 @@ pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>)
}
}

const SELF_LOCAL: Local = Local::from_u32(1);
const FIELD_ZERO: FieldIdx = FieldIdx::from_u32(0);

pub(super) fn remove_dead_coroutine_switch_variants(body: &mut Body<'_>) {
let Some(coroutine_layout) = body.coroutine_layout_raw() else {
// Not a coroutine; no coroutine variants to remove.
return;
};

let bb0 = &body.basic_blocks[START_BLOCK];

let is_pinned = match body.coroutine_kind().unwrap() {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => false,
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
| CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
| CoroutineKind::Coroutine(_) => true,
};
// This is essentially our off-brand `Underefer`. This stores the set of locals
// that we have determined to contain references to the coroutine discriminant.
// If the self type is not pinned, this is just going to be `_1`. However, if
// the self type is pinned, the derefer will emit statements of the form:
// _x = CopyForDeref (_1.0);
// We'll store the local for `_x` so that we can later detect discriminant stores
// of the form:
// Discriminant((*_x)) = ...
// which correspond to reachable variants of the coroutine.
let mut discr_locals = if is_pinned {
let Some(stmt) = bb0.statements.get(0) else {
// The coroutine body may have been turned into a single `unreachable`.
return;
};
// We match `CopyForDeref` (which is what gets emitted from the state transform
// pass), but also we match *regular* `Copy`, which is what GVN may optimize it to.
let StatementKind::Assign(box (
place,
Rvalue::Use(Operand::Copy(deref_place)) | Rvalue::CopyForDeref(deref_place),
)) = &stmt.kind
else {
panic!("The first statement of a coroutine is not a self deref");
};
let PlaceRef { local: SELF_LOCAL, projection: &[PlaceElem::Field(FIELD_ZERO, _)] } =
deref_place.as_ref()
else {
panic!("The first statement of a coroutine is not a self deref");
};
FxHashSet::from_iter([place.as_local().unwrap()])
} else {
FxHashSet::from_iter([SELF_LOCAL])
};

// The starting block of all coroutines is a switch for the coroutine variants.
// This is preceded by a read of the discriminant. If we don't find this, then
// we must have optimized away the switch, so bail.
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(discr_local))) =
&bb0.statements[if is_pinned { 1 } else { 0 }].kind
else {
// The following statement is not a discriminant read. We may have
// optimized it out, so bail gracefully.
return;
};
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = (*discr_local).as_ref()
else {
// We expect the discriminant to have read `&mut self`,
// so we expect the place to be a deref. If we didn't, then
// it may have been optimized out, so bail gracefully.
return;
};
if !discr_locals.contains(&deref_local) {
// The place being read isn't `_1` (self) or a `Derefer`-inserted local.
// It may have been optimized out, so bail gracefully.
return;
}
let TerminatorKind::SwitchInt { discr: Operand::Move(place), targets } = &bb0.terminator().kind
else {
// When panic=abort, we may end up folding away the other variants of the
// coroutine, and end up with ths `SwitchInt` getting replaced. In this
// case, there's no need to do this optimization, so bail gracefully.
return;
};
if place != discr_place {
// Make sure we don't try to match on some other `SwitchInt`; we should be
// matching on the discriminant we just read.
return;
}

let mut visited = DenseBitSet::new_empty(body.basic_blocks.len());
let mut worklist = vec![];
let mut visited_variants = DenseBitSet::new_empty(coroutine_layout.variant_fields.len());

// Insert unresumed (initial), returned, panicked variants.
// We treat these as always reachable.
visited_variants.insert(VariantIdx::from_usize(0));
visited_variants.insert(VariantIdx::from_usize(1));
visited_variants.insert(VariantIdx::from_usize(2));
worklist.push(targets.target_for_value(0));
worklist.push(targets.target_for_value(1));
worklist.push(targets.target_for_value(2));

// Walk all of the reachable variant blocks.
while let Some(block) = worklist.pop() {
if !visited.insert(block) {
continue;
}

let data = &body.basic_blocks[block];
for stmt in &data.statements {
match &stmt.kind {
// If we see a `SetDiscriminant` statement for our coroutine,
// mark that variant as reachable and add it to the worklist.
StatementKind::SetDiscriminant { place, variant_index } => {
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } =
(**place).as_ref()
else {
continue;
};
if !discr_locals.contains(&deref_local) {
continue;
}
visited_variants.insert(*variant_index);
worklist.push(targets.target_for_value(variant_index.as_u32().into()));
}
// The derefer may have inserted a local to access the variant.
// Make sure we keep track of it here.
StatementKind::Assign(box (place, Rvalue::CopyForDeref(deref_place))) => {
if !is_pinned {
continue;
}
let PlaceRef {
local: SELF_LOCAL,
projection: &[PlaceElem::Field(FIELD_ZERO, _)],
} = deref_place.as_ref()
else {
continue;
};
discr_locals.insert(place.as_local().unwrap());
}
_ => {}
}
}

// Also walk all the successors of this block.
if let Some(term) = &data.terminator {
worklist.extend(term.successors());
}
}

// Filter out the variants that are unreachable.
let TerminatorKind::SwitchInt { targets, .. } =
&mut body.basic_blocks.as_mut()[START_BLOCK].terminator_mut().kind
else {
unreachable!();
};
*targets = SwitchTargets::new(
targets
.iter()
.filter(|(idx, _)| visited_variants.contains(VariantIdx::from_u32(*idx as u32))),
targets.otherwise(),
);

// FIXME: We could remove dead variant fields from the coroutine layout, too.
}

pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
// CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
- // MIR for `outer::{closure#0}` before SimplifyCfg-final
+ // MIR for `outer::{closure#0}` after SimplifyCfg-final
/* coroutine_layout = CoroutineLayout {
field_tys: {
_0: CoroutineSavedTy {
ty: Coroutine(
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
[
(),
std::future::ResumeTy,
(),
(),
CoroutineWitness(
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
[],
),
(),
],
),
source_info: SourceInfo {
span: $DIR/coroutine_dead_variants.rs:13:9: 13:22 (#16),
scope: scope[0],
},
ignore_for_traits: false,
},
},
variant_fields: {
Unresumed(0): [],
Returned (1): [],
Panicked (2): [],
Suspend0 (3): [_0],
},
storage_conflicts: BitMatrix(1x1) {
(_0, _0),
},
} */

fn outer::{closure#0}(_1: Pin<&mut {async fn body of outer()}>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _2;
let mut _0: std::task::Poll<()>;
let mut _3: {async fn body of inner()};
let mut _4: {async fn body of inner()};
let mut _5: std::task::Poll<()>;
let mut _6: std::pin::Pin<&mut {async fn body of inner()}>;
let mut _7: &mut {async fn body of inner()};
let mut _8: &mut std::task::Context<'_>;
let mut _9: isize;
let mut _11: ();
let mut _12: &mut std::task::Context<'_>;
let mut _13: u32;
let mut _14: &mut {async fn body of outer()};
scope 1 {
debug __awaitee => (((*(_1.0: &mut {async fn body of outer()})) as variant#3).0: {async fn body of inner()});
let _10: ();
scope 2 {
debug result => const ();
}
}

bb0: {
_14 = copy (_1.0: &mut {async fn body of outer()});
_13 = discriminant((*_14));
- switchInt(move _13) -> [0: bb1, 1: bb15, 3: bb14, otherwise: bb8];
+ switchInt(move _13) -> [0: bb2, 1: bb4, otherwise: bb1];
}

bb1: {
- nop;
- goto -> bb12;
- }
-
- bb2: {
- StorageLive(_3);
- StorageLive(_4);
- _4 = inner() -> [return: bb3, unwind unreachable];
- }
-
- bb3: {
- _3 = <{async fn body of inner()} as IntoFuture>::into_future(move _4) -> [return: bb4, unwind unreachable];
- }
-
- bb4: {
- StorageDead(_4);
- (((*_14) as variant#3).0: {async fn body of inner()}) = move _3;
- goto -> bb5;
- }
-
- bb5: {
- StorageLive(_5);
- StorageLive(_6);
- _7 = &mut (((*_14) as variant#3).0: {async fn body of inner()});
- _6 = Pin::<&mut {async fn body of inner()}>::new_unchecked(copy _7) -> [return: bb6, unwind unreachable];
- }
-
- bb6: {
- nop;
- _5 = <{async fn body of inner()} as Future>::poll(move _6, copy _2) -> [return: bb7, unwind unreachable];
- }
-
- bb7: {
- StorageDead(_6);
- _9 = discriminant(_5);
- switchInt(move _9) -> [0: bb10, 1: bb9, otherwise: bb8];
- }
-
- bb8: {
unreachable;
}

- bb9: {
- StorageDead(_5);
- _0 = const Poll::<()>::Pending;
- StorageDead(_3);
- discriminant((*_14)) = 3;
- return;
- }
-
- bb10: {
- StorageLive(_10);
- nop;
- StorageDead(_10);
- StorageDead(_5);
- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb11, unwind unreachable];
- }
-
- bb11: {
- StorageDead(_3);
+ bb2: {
_11 = const ();
- goto -> bb13;
+ goto -> bb3;
}

- bb12: {
- _11 = const ();
- goto -> bb13;
- }
-
- bb13: {
+ bb3: {
_0 = Poll::<()>::Ready(const ());
discriminant((*_14)) = 1;
return;
}

- bb14: {
- StorageLive(_3);
- nop;
- goto -> bb5;
- }
-
- bb15: {
- assert(const false, "`async fn` resumed after completion") -> [success: bb15, unwind unreachable];
+ bb4: {
+ assert(const false, "`async fn` resumed after completion") -> [success: bb4, unwind unreachable];
}
}

Loading

0 comments on commit 802fb2a

Please sign in to comment.