Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 99eed4b

Browse files
committedOct 20, 2024··
mir-opt: Merge all branch BBs into a single copy statement for enum
1 parent 70392ec commit 99eed4b

13 files changed

+560
-61
lines changed
 

‎compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ mod lower_intrinsics;
8686
mod lower_slice_len;
8787
mod match_branches;
8888
mod mentioned_items;
89+
mod merge_branches;
8990
mod multiple_return_terminators;
9091
mod nrvo;
9192
mod post_drop_elaboration;
@@ -611,6 +612,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
611612
&dead_store_elimination::DeadStoreElimination::Initial,
612613
&gvn::GVN,
613614
&simplify::SimplifyLocals::AfterGVN,
615+
&merge_branches::MergeBranchSimplification,
614616
&dataflow_const_prop::DataflowConstProp,
615617
&single_use_consts::SingleUseConsts,
616618
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//! This pass attempts to merge all branches to eliminate switch terminator.
2+
//! Ideally, we could combine it with `MatchBranchSimplification`, as these two passes
3+
//! match and merge statements with different patterns. Given the compile time and
4+
//! code complexity, we have not merged them into a more general pass for now.
5+
use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
6+
use rustc_index::bit_set::BitSet;
7+
use rustc_middle::mir::patch::MirPatch;
8+
use rustc_middle::mir::*;
9+
use rustc_middle::ty;
10+
use rustc_middle::ty::util::Discr;
11+
use rustc_middle::ty::{ParamEnv, TyCtxt};
12+
use rustc_mir_dataflow::impls::borrowed_locals;
13+
14+
use crate::dead_store_elimination::DeadStoreAnalysis;
15+
16+
pub(super) struct MergeBranchSimplification;
17+
18+
impl<'tcx> crate::MirPass<'tcx> for MergeBranchSimplification {
19+
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
20+
sess.mir_opt_level() >= 2
21+
}
22+
23+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
24+
let def_id = body.source.def_id();
25+
let param_env = tcx.param_env_reveal_all_normalized(def_id);
26+
27+
let borrowed_locals = borrowed_locals(body);
28+
let mut dead_store_analysis = DeadStoreAnalysis::new(tcx, body, &borrowed_locals);
29+
30+
for switch_bb_idx in body.basic_blocks.indices() {
31+
let bbs = &*body.basic_blocks;
32+
let Some((switch_discr, targets)) = bbs[switch_bb_idx].terminator().kind.as_switch()
33+
else {
34+
continue;
35+
};
36+
// Check that destinations are identical, and if not, then don't optimize this block.
37+
let mut targets_iter = targets.iter();
38+
let first_terminator_kind = &bbs[targets_iter.next().unwrap().1].terminator().kind;
39+
if targets_iter.any(|(_, other_target)| {
40+
first_terminator_kind != &bbs[other_target].terminator().kind
41+
}) {
42+
continue;
43+
}
44+
// We require that the possible target blocks all be distinct.
45+
if !targets.is_distinct() {
46+
continue;
47+
}
48+
if !bbs[targets.otherwise()].is_empty_unreachable() {
49+
continue;
50+
}
51+
// Check if the copy source matches the following pattern.
52+
// _2 = discriminant(*_1); // "*_1" is the expected the copy source.
53+
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
54+
let Some(&Statement {
55+
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(src_place))),
56+
..
57+
}) = bbs[switch_bb_idx].statements.last()
58+
else {
59+
continue;
60+
};
61+
if switch_discr.place() != Some(discr_place) {
62+
continue;
63+
}
64+
let src_ty = src_place.ty(body.local_decls(), tcx);
65+
if let Some(dest_place) = can_simplify_to_copy(
66+
tcx,
67+
param_env,
68+
body,
69+
targets,
70+
src_place,
71+
src_ty,
72+
&mut dead_store_analysis,
73+
) {
74+
let statement_index = bbs[switch_bb_idx].statements.len();
75+
let parent_end = Location { block: switch_bb_idx, statement_index };
76+
let mut patch = MirPatch::new(body);
77+
patch.add_assign(parent_end, dest_place, Rvalue::Use(Operand::Copy(src_place)));
78+
patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone());
79+
patch.apply(body);
80+
super::simplify::remove_dead_blocks(body);
81+
// After modifying the MIR, the result of `MaybeTransitiveLiveLocals` may become invalid,
82+
// keeping it simple to process only once.
83+
break;
84+
}
85+
}
86+
}
87+
}
88+
89+
/// The GVN simplified
90+
/// ```ignore (syntax-highlighting-only)
91+
/// match a {
92+
/// Foo::A(x) => Foo::A(*x),
93+
/// Foo::B => Foo::B
94+
/// }
95+
/// ```
96+
/// to
97+
/// ```ignore (syntax-highlighting-only)
98+
/// match a {
99+
/// Foo::A(_x) => a, // copy a
100+
/// Foo::B => Foo::B
101+
/// }
102+
/// ```
103+
/// This function answers whether it can be simplified to a copy statement
104+
/// by returning the copy destination.
105+
fn can_simplify_to_copy<'tcx>(
106+
tcx: TyCtxt<'tcx>,
107+
param_env: ParamEnv<'tcx>,
108+
body: &Body<'tcx>,
109+
targets: &SwitchTargets,
110+
src_place: Place<'tcx>,
111+
src_ty: tcx::PlaceTy<'tcx>,
112+
dead_store_analysis: &mut DeadStoreAnalysis<'tcx, '_, '_>,
113+
) -> Option<Place<'tcx>> {
114+
let mut targets_iter = targets.iter();
115+
let (first_index, first_target) = targets_iter.next()?;
116+
let dest_place = find_copy_assign(
117+
tcx,
118+
param_env,
119+
body,
120+
first_index,
121+
first_target,
122+
src_place,
123+
src_ty,
124+
dead_store_analysis,
125+
)?;
126+
let dest_ty = dest_place.ty(body.local_decls(), tcx);
127+
if dest_ty.ty != src_ty.ty {
128+
return None;
129+
}
130+
for (other_index, other_target) in targets_iter {
131+
if dest_place
132+
!= find_copy_assign(
133+
tcx,
134+
param_env,
135+
body,
136+
other_index,
137+
other_target,
138+
src_place,
139+
src_ty,
140+
dead_store_analysis,
141+
)?
142+
{
143+
return None;
144+
}
145+
}
146+
Some(dest_place)
147+
}
148+
149+
// Find the single assignment statement where the source of the copy is from the source.
150+
// All other statements are dead statements or have no effect that can be eliminated.
151+
fn find_copy_assign<'tcx>(
152+
tcx: TyCtxt<'tcx>,
153+
param_env: ParamEnv<'tcx>,
154+
body: &Body<'tcx>,
155+
index: u128,
156+
target_block: BasicBlock,
157+
src_place: Place<'tcx>,
158+
src_ty: tcx::PlaceTy<'tcx>,
159+
dead_store_analysis: &mut DeadStoreAnalysis<'tcx, '_, '_>,
160+
) -> Option<Place<'tcx>> {
161+
let statements = &body.basic_blocks[target_block].statements;
162+
if statements.is_empty() {
163+
return None;
164+
}
165+
let assign_stmt = if statements.len() == 1 {
166+
0
167+
} else {
168+
let mut lived_stmts: BitSet<usize> = BitSet::new_filled(statements.len());
169+
let mut expected_assign_stmt = None;
170+
for (statement_index, statement) in statements.iter().enumerate().rev() {
171+
let loc = Location { block: target_block, statement_index };
172+
if dead_store_analysis.is_dead_store(loc, &statement.kind) {
173+
lived_stmts.remove(statement_index);
174+
} else if matches!(
175+
statement.kind,
176+
StatementKind::StorageLive(_) | StatementKind::StorageDead(_)
177+
) {
178+
} else if matches!(statement.kind, StatementKind::Assign(_))
179+
&& expected_assign_stmt.is_none()
180+
{
181+
// There is only one assign statement that cannot be ignored
182+
// that can be used as an expected copy statement.
183+
expected_assign_stmt = Some(statement_index);
184+
lived_stmts.remove(statement_index);
185+
} else {
186+
return None;
187+
}
188+
}
189+
let expected_assign = expected_assign_stmt?;
190+
if !lived_stmts.is_empty() {
191+
// We can ignore the paired StorageLive and StorageDead.
192+
let mut storage_live_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len());
193+
for stmt_index in lived_stmts.iter() {
194+
let statement = &statements[stmt_index];
195+
match &statement.kind {
196+
StatementKind::StorageLive(local) if storage_live_locals.insert(*local) => {}
197+
StatementKind::StorageDead(local) if storage_live_locals.remove(*local) => {}
198+
_ => return None,
199+
}
200+
}
201+
if !storage_live_locals.is_empty() {
202+
return None;
203+
}
204+
}
205+
expected_assign
206+
};
207+
let &(dest_place, ref rvalue) = statements[assign_stmt].kind.as_assign()?;
208+
let dest_ty = dest_place.ty(body.local_decls(), tcx);
209+
if dest_ty.ty != src_ty.ty {
210+
return None;
211+
}
212+
let ty::Adt(def, _) = dest_ty.ty.kind() else {
213+
return None;
214+
};
215+
match rvalue {
216+
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
217+
Rvalue::Use(Operand::Constant(box constant))
218+
if let Const::Val(const_, ty) = constant.const_ =>
219+
{
220+
let (ecx, op) = mk_eval_cx_for_const_val(tcx.at(constant.span), param_env, const_, ty)?;
221+
let variant = ecx.read_discriminant(&op).discard_err()?;
222+
if !def.variants()[variant].fields.is_empty() {
223+
return None;
224+
}
225+
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
226+
if val != index {
227+
return None;
228+
}
229+
}
230+
Rvalue::Use(Operand::Copy(place)) if *place == src_place => {}
231+
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
232+
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
233+
if fields.is_empty()
234+
&& let Some(Discr { val, .. }) =
235+
src_ty.ty.discriminant_for_variant(tcx, *variant_index)
236+
&& val == index => {}
237+
_ => return None,
238+
}
239+
Some(dest_place)
240+
}

‎tests/codegen/match-optimizes-away.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
//
2-
//@ compile-flags: -O
1+
//@ compile-flags: -O -Cno-prepopulate-passes
2+
33
#![crate_type = "lib"]
44

55
pub enum Three {
@@ -19,8 +19,9 @@ pub enum Four {
1919
#[no_mangle]
2020
pub fn three_valued(x: Three) -> Three {
2121
// CHECK-LABEL: @three_valued
22-
// CHECK-NEXT: {{^.*:$}}
23-
// CHECK-NEXT: ret i8 %0
22+
// CHECK-SAME: (i8{{.*}} [[X:%x]])
23+
// CHECK-NEXT: start:
24+
// CHECK-NEXT: ret i8 [[X]]
2425
match x {
2526
Three::A => Three::A,
2627
Three::B => Three::B,
@@ -31,8 +32,9 @@ pub fn three_valued(x: Three) -> Three {
3132
#[no_mangle]
3233
pub fn four_valued(x: Four) -> Four {
3334
// CHECK-LABEL: @four_valued
34-
// CHECK-NEXT: {{^.*:$}}
35-
// CHECK-NEXT: ret i16 %0
35+
// CHECK-SAME: (i16{{.*}} [[X:%x]])
36+
// CHECK-NEXT: start:
37+
// CHECK-NEXT: ret i16 [[X]]
3638
match x {
3739
Four::A => Four::A,
3840
Four::B => Four::B,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
- // MIR for `no_fields` before MergeBranchSimplification
2+
+ // MIR for `no_fields` after MergeBranchSimplification
3+
4+
fn no_fields(_1: NoFields) -> NoFields {
5+
debug a => _1;
6+
let mut _0: NoFields;
7+
let mut _2: isize;
8+
9+
bb0: {
10+
_2 = discriminant(_1);
11+
- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
12+
+ _0 = copy _1;
13+
+ goto -> bb1;
14+
}
15+
16+
bb1: {
17+
- unreachable;
18+
- }
19+
-
20+
- bb2: {
21+
- _0 = NoFields::B;
22+
- goto -> bb4;
23+
- }
24+
-
25+
- bb3: {
26+
- _0 = NoFields::A;
27+
- goto -> bb4;
28+
- }
29+
-
30+
- bb4: {
31+
return;
32+
}
33+
}
34+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
- // MIR for `no_fields_failed` before MergeBranchSimplification
2+
+ // MIR for `no_fields_failed` after MergeBranchSimplification
3+
4+
fn no_fields_failed(_1: NoFields) -> NoFields {
5+
debug a => _1;
6+
let mut _0: NoFields;
7+
let mut _2: isize;
8+
9+
bb0: {
10+
_2 = discriminant(_1);
11+
switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
12+
}
13+
14+
bb1: {
15+
unreachable;
16+
}
17+
18+
bb2: {
19+
_0 = NoFields::A;
20+
goto -> bb4;
21+
}
22+
23+
bb3: {
24+
_0 = NoFields::B;
25+
goto -> bb4;
26+
}
27+
28+
bb4: {
29+
return;
30+
}
31+
}
32+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
- // MIR for `no_fields_mismatch_type_failed` before MergeBranchSimplification
2+
+ // MIR for `no_fields_mismatch_type_failed` after MergeBranchSimplification
3+
4+
fn no_fields_mismatch_type_failed(_1: NoFields) -> NoFields2 {
5+
debug a => _1;
6+
let mut _0: NoFields2;
7+
let mut _2: isize;
8+
9+
bb0: {
10+
_2 = discriminant(_1);
11+
switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
12+
}
13+
14+
bb1: {
15+
unreachable;
16+
}
17+
18+
bb2: {
19+
_0 = NoFields2::B;
20+
goto -> bb4;
21+
}
22+
23+
bb3: {
24+
_0 = NoFields2::A;
25+
goto -> bb4;
26+
}
27+
28+
bb4: {
29+
return;
30+
}
31+
}
32+

0 commit comments

Comments
 (0)
Please sign in to comment.