@@ -132,15 +132,12 @@ use std::collections::hash_map::{Entry, OccupiedEntry};
132132use  crate :: MirPass ; 
133133use  rustc_data_structures:: fx:: FxHashMap ; 
134134use  rustc_index:: bit_set:: BitSet ; 
135+ use  rustc_middle:: mir:: visit:: { MutVisitor ,  PlaceContext ,  Visitor } ; 
135136use  rustc_middle:: mir:: { dump_mir,  PassWhere } ; 
136137use  rustc_middle:: mir:: { 
137138    traversal,  BasicBlock ,  Body ,  InlineAsmOperand ,  Local ,  LocalKind ,  Location ,  Operand ,  Place , 
138139    Rvalue ,  Statement ,  StatementKind ,  TerminatorKind , 
139140} ; 
140- use  rustc_middle:: mir:: { 
141-     visit:: { MutVisitor ,  PlaceContext ,  Visitor } , 
142-     ProjectionElem , 
143- } ; 
144141use  rustc_middle:: ty:: TyCtxt ; 
145142use  rustc_mir_dataflow:: impls:: MaybeLiveLocals ; 
146143use  rustc_mir_dataflow:: { Analysis ,  ResultsCursor } ; 
@@ -359,40 +356,45 @@ struct FilterInformation<'a, 'body, 'alloc, 'tcx> {
359356// through these methods, and not directly. 
360357impl < ' alloc >  Candidates < ' alloc >  { 
361358    /// Just `Vec::retain`, but the condition is inverted and we add debugging output 
362-      fn  vec_remove_debug ( 
359+      fn  vec_filter_candidates ( 
363360        src :  Local , 
364361        v :  & mut  Vec < Local > , 
365-         mut  f :  impl  FnMut ( Local )  -> bool , 
362+         mut  f :  impl  FnMut ( Local )  -> CandidateFilter , 
366363        at :  Location , 
367364    )  { 
368365        v. retain ( |dest| { 
369366            let  remove = f ( * dest) ; 
370-             if  remove { 
367+             if  remove ==  CandidateFilter :: Remove   { 
371368                trace ! ( "eliminating {:?} => {:?} due to conflict at {:?}" ,  src,  dest,  at) ; 
372369            } 
373-             ! remove
370+             remove ==  CandidateFilter :: Keep 
374371        } ) ; 
375372    } 
376373
377-     /// `vec_remove_debug ` but for an `Entry` 
378-      fn  entry_remove ( 
374+     /// `vec_filter_candidates ` but for an `Entry` 
375+      fn  entry_filter_candidates ( 
379376        mut  entry :  OccupiedEntry < ' _ ,  Local ,  Vec < Local > > , 
380377        p :  Local , 
381-         f :  impl  FnMut ( Local )  -> bool , 
378+         f :  impl  FnMut ( Local )  -> CandidateFilter , 
382379        at :  Location , 
383380    )  { 
384381        let  candidates = entry. get_mut ( ) ; 
385-         Self :: vec_remove_debug ( p,  candidates,  f,  at) ; 
382+         Self :: vec_filter_candidates ( p,  candidates,  f,  at) ; 
386383        if  candidates. len ( )  == 0  { 
387384            entry. remove ( ) ; 
388385        } 
389386    } 
390387
391-     /// Removes all candidates `(p, q)` or `(q, p)` where `p` is the indicated local and `f(q)` is true. 
392-      fn  remove_candidates_if ( & mut  self ,  p :  Local ,  mut  f :  impl  FnMut ( Local )  -> bool ,  at :  Location )  { 
388+     /// For all candidates `(p, q)` or `(q, p)` removes the candidate if `f(q)` says to do so 
389+      fn  filter_candidates_by ( 
390+         & mut  self , 
391+         p :  Local , 
392+         mut  f :  impl  FnMut ( Local )  -> CandidateFilter , 
393+         at :  Location , 
394+     )  { 
393395        // Cover the cases where `p` appears as a `src` 
394396        if  let  Entry :: Occupied ( entry)  = self . c . entry ( p)  { 
395-             Self :: entry_remove ( entry,  p,  & mut  f,  at) ; 
397+             Self :: entry_filter_candidates ( entry,  p,  & mut  f,  at) ; 
396398        } 
397399        // And the cases where `p` appears as a `dest` 
398400        let  Some ( srcs)  = self . reverse . get_mut ( & p)  else  { 
@@ -401,18 +403,31 @@ impl<'alloc> Candidates<'alloc> {
401403        // We use `retain` here to remove the elements from the reverse set if we've removed the 
402404        // matching candidate in the forward set. 
403405        srcs. retain ( |src| { 
404-             if  ! f ( * src)  { 
406+             if  f ( * src)  ==  CandidateFilter :: Keep  { 
405407                return  true ; 
406408            } 
407409            let  Entry :: Occupied ( entry)  = self . c . entry ( * src)  else  { 
408410                return  false ; 
409411            } ; 
410-             Self :: entry_remove ( entry,  * src,  |dest| dest == p,  at) ; 
412+             Self :: entry_filter_candidates ( 
413+                 entry, 
414+                 * src, 
415+                 |dest| { 
416+                     if  dest == p {  CandidateFilter :: Remove  }  else  {  CandidateFilter :: Keep  } 
417+                 } , 
418+                 at, 
419+             ) ; 
411420            false 
412421        } ) ; 
413422    } 
414423} 
415424
425+ #[ derive( Copy ,  Clone ,  PartialEq ,  Eq ) ]  
426+ enum  CandidateFilter  { 
427+     Keep , 
428+     Remove , 
429+ } 
430+ 
416431impl < ' a ,  ' body ,  ' alloc ,  ' tcx >  FilterInformation < ' a ,  ' body ,  ' alloc ,  ' tcx >  { 
417432    /// Filters the set of candidates to remove those that conflict. 
418433     /// 
@@ -460,7 +475,7 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> {
460475            for  ( i,  statement)  in  data. statements . iter ( ) . enumerate ( ) . rev ( )  { 
461476                self . at  = Location  {  block,  statement_index :  i } ; 
462477                self . live . seek_after_primary_effect ( self . at ) ; 
463-                 self . get_statement_write_info ( & statement. kind ) ; 
478+                 self . write_info . for_statement ( & statement. kind ,   self . body ) ; 
464479                self . apply_conflicts ( ) ; 
465480            } 
466481        } 
@@ -469,80 +484,59 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> {
469484    fn  apply_conflicts ( & mut  self )  { 
470485        let  writes = & self . write_info . writes ; 
471486        for  p in  writes { 
472-             self . candidates . remove_candidates_if ( 
487+             let  other_skip = self . write_info . skip_pair . and_then ( |( a,  b) | { 
488+                 if  a == * p { 
489+                     Some ( b) 
490+                 }  else  if  b == * p { 
491+                     Some ( a) 
492+                 }  else  { 
493+                     None 
494+                 } 
495+             } ) ; 
496+             self . candidates . filter_candidates_by ( 
473497                * p, 
474-                 // It is possible that a local may be live for less than the 
475-                 // duration of a statement This happens in the case of function 
476-                 // calls or inline asm. Because of this, we also mark locals as 
477-                 // conflicting when both of them are written to in the same 
478-                 // statement. 
479-                 |q| self . live . contains ( q)  || writes. contains ( & q) , 
498+                 |q| { 
499+                     if  Some ( q)  == other_skip { 
500+                         return  CandidateFilter :: Keep ; 
501+                     } 
502+                     // It is possible that a local may be live for less than the 
503+                     // duration of a statement This happens in the case of function 
504+                     // calls or inline asm. Because of this, we also mark locals as 
505+                     // conflicting when both of them are written to in the same 
506+                     // statement. 
507+                     if  self . live . contains ( q)  || writes. contains ( & q)  { 
508+                         CandidateFilter :: Remove 
509+                     }  else  { 
510+                         CandidateFilter :: Keep 
511+                     } 
512+                 } , 
480513                self . at , 
481514            ) ; 
482515        } 
483516    } 
484- 
485-     /// Gets the write info for the `statement`. 
486-      fn  get_statement_write_info ( & mut  self ,  statement :  & StatementKind < ' tcx > )  { 
487-         self . write_info . writes . clear ( ) ; 
488-         match  statement { 
489-             StatementKind :: Assign ( box ( lhs,  rhs) )  => match  rhs { 
490-                 Rvalue :: Use ( op)  => { 
491-                     if  !lhs. is_indirect ( )  { 
492-                         self . get_assign_use_write_info ( * lhs,  op) ; 
493-                         return ; 
494-                     } 
495-                 } 
496-                 _ => ( ) , 
497-             } , 
498-             _ => ( ) , 
499-         } 
500- 
501-         self . write_info . for_statement ( statement) ; 
502-     } 
503- 
504-     fn  get_assign_use_write_info ( & mut  self ,  lhs :  Place < ' tcx > ,  rhs :  & Operand < ' tcx > )  { 
505-         // We register the writes for the operand unconditionally 
506-         self . write_info . add_operand ( rhs) ; 
507-         // However, we cannot do the same thing for the `lhs` as that would always block the 
508-         // optimization. Instead, we consider removing candidates manually. 
509-         let  Some ( rhs)  = rhs. place ( )  else  { 
510-             self . write_info . add_place ( lhs) ; 
511-             return ; 
512-         } ; 
513-         // Find out which candidate pair we should skip, if any 
514-         let  Some ( ( src,  dest) )  = places_to_candidate_pair ( lhs,  rhs,  self . body )  else  { 
515-             self . write_info . add_place ( lhs) ; 
516-             return ; 
517-         } ; 
518-         self . candidates . remove_candidates_if ( 
519-             lhs. local , 
520-             |other| { 
521-                 // Check if this is the candidate pair that should not be removed 
522-                 if  ( lhs. local  == src && other == dest)  || ( lhs. local  == dest && other == src)  { 
523-                     return  false ; 
524-                 } 
525-                 // Otherwise, do the "standard" thing 
526-                 self . live . contains ( other) 
527-             } , 
528-             self . at , 
529-         ) 
530-     } 
531517} 
532518
533519/// Describes where a statement/terminator writes to 
534520#[ derive( Default ,  Debug ) ]  
535521struct  WriteInfo  { 
536522    writes :  Vec < Local > , 
523+     /// If this pair of locals is a candidate pair, completely skip processing it during this 
524+      /// statement. All other candidates are unaffected. 
525+      skip_pair :  Option < ( Local ,  Local ) > , 
537526} 
538527
539528impl  WriteInfo  { 
540-     fn  for_statement < ' tcx > ( & mut  self ,  statement :  & StatementKind < ' tcx > )  { 
529+     fn  for_statement < ' tcx > ( & mut  self ,  statement :  & StatementKind < ' tcx > ,  body :  & Body < ' tcx > )  { 
530+         self . reset ( ) ; 
541531        match  statement { 
542532            StatementKind :: Assign ( box ( lhs,  rhs) )  => { 
543533                self . add_place ( * lhs) ; 
544534                match  rhs { 
545-                     Rvalue :: Use ( op)  | Rvalue :: Repeat ( op,  _)  => { 
535+                     Rvalue :: Use ( op)  => { 
536+                         self . add_operand ( op) ; 
537+                         self . consider_skipping_for_assign_use ( * lhs,  op,  body) ; 
538+                     } 
539+                     Rvalue :: Repeat ( op,  _)  => { 
546540                        self . add_operand ( op) ; 
547541                    } 
548542                    Rvalue :: Cast ( _,  op,  _) 
@@ -586,8 +580,22 @@ impl WriteInfo {
586580        } 
587581    } 
588582
583+     fn  consider_skipping_for_assign_use < ' tcx > ( 
584+         & mut  self , 
585+         lhs :  Place < ' tcx > , 
586+         rhs :  & Operand < ' tcx > , 
587+         body :  & Body < ' tcx > , 
588+     )  { 
589+         let  Some ( rhs)  = rhs. place ( )  else  { 
590+             return 
591+         } ; 
592+         if  let  Some ( pair)  = places_to_candidate_pair ( lhs,  rhs,  body)  { 
593+             self . skip_pair  = Some ( pair) ; 
594+         } 
595+     } 
596+ 
589597    fn  for_terminator < ' tcx > ( & mut  self ,  terminator :  & TerminatorKind < ' tcx > )  { 
590-         self . writes . clear ( ) ; 
598+         self . reset ( ) ; 
591599        match  terminator { 
592600            TerminatorKind :: SwitchInt  {  discr :  op,  .. } 
593601            | TerminatorKind :: Assert  {  cond :  op,  .. }  => { 
@@ -657,15 +665,16 @@ impl WriteInfo {
657665            Operand :: Copy ( _)  | Operand :: Constant ( _)  => ( ) , 
658666        } 
659667    } 
668+ 
669+     fn  reset ( & mut  self )  { 
670+         self . writes . clear ( ) ; 
671+         self . skip_pair  = None ; 
672+     } 
660673} 
661674
662675///////////////////////////////////////////////////// 
663676// Candidate accumulation 
664677
665- fn  is_constant < ' tcx > ( place :  Place < ' tcx > )  -> bool  { 
666-     place. projection . iter ( ) . all ( |p| !matches ! ( p,  ProjectionElem :: Deref  | ProjectionElem :: Index ( _) ) ) 
667- } 
668- 
669678/// If the pair of places is being considered for merging, returns the candidate which would be 
670679/// merged in order to accomplish this. 
671680/// 
@@ -741,10 +750,6 @@ impl<'tcx> Visitor<'tcx> for FindAssignments<'_, '_, 'tcx> {
741750            Rvalue :: Use ( Operand :: Copy ( rhs)  | Operand :: Move ( rhs) ) , 
742751        ) )  = & statement. kind 
743752        { 
744-             if  !is_constant ( * lhs)  || !is_constant ( * rhs)  { 
745-                 return ; 
746-             } 
747- 
748753            let  Some ( ( src,  dest) )  = places_to_candidate_pair ( * lhs,  * rhs,  self . body )  else  { 
749754                return ; 
750755            } ; 
0 commit comments