diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index 2a516c5cb..857eec4c8 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -36,21 +36,8 @@ pub trait ComposablePass: Sized { /// /// See [`PassScope`] for more details. /// - /// In `hugr 0.25.*`, this configuration is only a guidance, and may be - /// ignored by the pass by using the default implementation. - /// - /// From `hugr >=0.26.0`, passes must respect the scope configuration. - // - // For hugr passes, this is tracked by - fn with_scope_internal(self, scope: impl Into) -> Self { - // Currently passes are not required to respect the scope configuration. - // - // - // deprecated: Remove default implementation in hugr 0.26.0, - // ensure all passes follow the scope configuration. - let _ = scope; - self - } + /// Since `hugr >=0.26.0`, passes must implement this to respect the scope configuration. + fn with_scope_internal(self, scope: impl Into) -> Self; /// Apply a function to the error type of this pass, returning a new /// [`ComposablePass`] that has the same result type. diff --git a/hugr-passes/src/inline_dfgs.rs b/hugr-passes/src/inline_dfgs.rs index a7c6aab0e..48e8f90c9 100644 --- a/hugr-passes/src/inline_dfgs.rs +++ b/hugr-passes/src/inline_dfgs.rs @@ -1,32 +1,33 @@ //! Provides [`InlineDFGsPass`], a pass for inlining all DFGs in a Hugr. use std::convert::Infallible; -use hugr_core::{ - Node, - hugr::{ - hugrmut::HugrMut, - patch::inline_dfg::{InlineDFG, InlineDFGError}, - }, +use hugr_core::hugr::{ + hugrmut::HugrMut, + patch::inline_dfg::{InlineDFG, InlineDFGError}, }; use itertools::Itertools; -use crate::ComposablePass; +use crate::{ComposablePass, PassScope}; /// Inlines all DFG nodes nested below the entrypoint. /// /// See [InlineDFG] for a rewrite to inline single DFGs. -#[derive(Debug, Clone)] -pub struct InlineDFGsPass; +#[derive(Debug, Default, Clone)] +pub struct InlineDFGsPass { + scope: PassScope, +} -impl> ComposablePass for InlineDFGsPass { +impl ComposablePass for InlineDFGsPass { type Error = Infallible; type Result = (); fn run(&self, h: &mut H) -> Result<(), Self::Error> { + let Some(r) = self.scope.root(h) else { + return Ok(()); + }; let dfgs = h - .entry_descendants() - .skip(1) // Skip the entrypoint itself - .filter(|&n| h.get_optype(n).is_dfg()) + .descendants(r) + .filter(|&n| n != h.entrypoint() && h.get_optype(n).is_dfg()) .collect_vec(); for dfg in dfgs { h.apply_patch(InlineDFG(dfg.into())) @@ -43,6 +44,11 @@ impl> ComposablePass for InlineDFGsPass { } Ok(()) } + + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = scope.into(); + self + } } #[cfg(test)] @@ -84,7 +90,7 @@ mod test { let mut h = outer.finish_hugr_with_outputs([a, b])?; assert_eq!(h.num_nodes(), 5 * 3 + 4); // 5 DFGs with I/O + 4 nodes for module/func roots - InlineDFGsPass.run(&mut h).unwrap(); + InlineDFGsPass::default().run(&mut h).unwrap(); // Root should be the only remaining DFG assert!(h.get_optype(h.entrypoint()).is_dfg()); diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 99a081a44..ceb58539c 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -38,6 +38,11 @@ pub use lower::{lower_ops, replace_many_ops}; #[expect(deprecated)] // Remove together pub use monomorphize::monomorphize; pub use monomorphize::{MonomorphizePass, mangle_name}; -pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; +#[deprecated( + note = "Use LocalizeEdgesPass::check_no_nonlocal_edges", + since = "0.26.0" +)] +#[expect(deprecated)] // Remove at same time +pub use non_local::ensure_no_nonlocal_edges; pub use replace_types::ReplaceTypes; pub use untuple::UntuplePass; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 9382a1d07..8669481fa 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -13,8 +13,8 @@ use hugr_core::{ use hugr_core::hugr::{HugrView, OpType, hugrmut::HugrMut}; use itertools::Itertools as _; -use crate::ComposablePass; use crate::composable::{ValidatePassError, validate_if_test}; +use crate::{ComposablePass, PassScope}; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -36,7 +36,7 @@ use crate::composable::{ValidatePassError, validate_if_test}; pub fn monomorphize( hugr: &mut impl HugrMut, ) -> Result<(), ValidatePassError> { - validate_if_test(MonomorphizePass, hugr) + validate_if_test(MonomorphizePass::default(), hugr) } fn is_polymorphic(fd: &FuncDefn) -> bool { @@ -197,21 +197,37 @@ fn instantiate( /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone)] -pub struct MonomorphizePass; +#[derive(Debug, Default, Clone)] +pub struct MonomorphizePass { + scope: PassScope, +} impl> ComposablePass for MonomorphizePass { type Error = Infallible; type Result = (); fn run(&self, h: &mut H) -> Result<(), Self::Error> { - let root = h.entrypoint(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - } + match self.scope { + PassScope::EntrypointFlat | PassScope::EntrypointRecursive => { + // for module-entrypoint, PassScope says to do nothing. (Monomorphization could.) + // for non-module-entrypoint, PassScope says not to touch Hugr outside entrypoint, + // so monomorphization cannot add any new functions --> do nothing. + // NOTE we could look to see if there are any existing instantations that + // we could use (!), but not atm. + } + PassScope::Global(_) => + // only generates new nodes, never changes signature of any existing node. + { + mono_scan(h, h.module_root(), None, &mut HashMap::new()) + } + }; Ok(()) } + + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = scope.into(); + self + } } /// Helper to create mangled representations of lists of [TypeArg]s. @@ -308,7 +324,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass.run(&mut hugr2).unwrap(); + MonomorphizePass::default().run(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -374,7 +390,7 @@ mod test { .count(), 3 ); - MonomorphizePass.run(&mut hugr)?; + MonomorphizePass::default().run(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -395,7 +411,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass.run(&mut mono2)?; + MonomorphizePass::default().run(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -551,7 +567,7 @@ mod test { let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); hugr.set_entrypoint(hugr.module_root()); // We want to act on everything, not just `main` - MonomorphizePass.run(&mut hugr).unwrap(); + MonomorphizePass::default().run(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -629,7 +645,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass.run(&mut hugr).unwrap(); + MonomorphizePass::default().run(&mut hugr).unwrap(); RemoveDeadFuncsPass::default() .with_scope(Preserve::Public) .run(&mut hugr) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index c8f5e850d..ae12952a0 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,21 +1,27 @@ //! This module provides functions for finding non-local edges //! in a Hugr and converting them to local edges. -use itertools::Itertools as _; - use hugr_core::{ HugrView, IncomingPort, Wire, hugr::hugrmut::HugrMut, types::{EdgeKind, Type}, }; -use crate::ComposablePass; +use crate::{ComposablePass, PassScope, composable::Preserve}; mod localize; use localize::ExtraSourceReqs; -/// [ComposablePass] wrapper for [remove_nonlocal_edges] -#[derive(Clone, Debug, Hash)] -pub struct LocalizeEdges; +/// Converts non-local edges in a Hugr into local ones, by inserting extra inputs to container +/// nodes and extra outports to Input nodes (and conversely to outputs of [DataflowBlock]s). +/// +/// Ignores [PassScope::recursive], as acts only on nonlocal edges *both* of whose endpoints +/// are within the subtree specified by [PassScope::root]. +/// +/// [DataflowBlock]: hugr_core::ops::DataflowBlock +#[derive(Clone, Debug, Default, Hash)] +pub struct LocalizeEdges { + scope: PassScope, +} /// Error from [LocalizeEdges] or [remove_nonlocal_edges] #[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, PartialEq)] @@ -28,23 +34,55 @@ impl ComposablePass for LocalizeEdges { type Result = (); fn run(&self, hugr: &mut H) -> Result { - remove_nonlocal_edges(hugr) + // Group all the non-local edges in the graph by target node, + // storing for each the source and type (well-defined as these are Value edges). + let edges = match self.check_no_nonlocal_edges(hugr) { + Ok(()) => return Ok(()), + Err(FindNonLocalEdgesError::Edges(edges)) => edges, + }; + + let nonlocal_edges: Vec<_> = edges + .into_iter() + .map(|(node, inport)| { + // unwrap because nonlocal_edges(hugr) already skips in-ports with !=1 linked outputs. + let (src_n, outp) = hugr.single_linked_output(node, inport).unwrap(); + debug_assert!(hugr.get_parent(src_n).unwrap() != hugr.get_parent(node).unwrap()); + let Some(EdgeKind::Value(ty)) = hugr.get_optype(src_n).port_kind(outp) else { + panic!("impossible") + }; + (node, (Wire::new(src_n, outp), ty)) + }) + .collect(); + + // We now compute the sources needed by each parent node. + let needs_sources_map = { + let mut bnsm = ExtraSourceReqs::default(); + for (target_node, (source, ty)) in nonlocal_edges.iter() { + let parent = hugr.get_parent(*target_node).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.add_edge(&*hugr, parent, *source, ty.clone()); + } + bnsm + }; + + debug_assert!(nonlocal_edges.iter().all(|(n, (source, _))| { + let source_parent = hugr.get_parent(source.node()).unwrap(); + let source_gp = hugr.get_parent(source_parent); + ancestors(*n, hugr) + .skip(1) + .take_while(|&a| a != source_parent && source_gp.is_none_or(|gp| a != gp)) + .all(|parent| needs_sources_map.parent_needs(parent, *source)) + })); + + needs_sources_map.thread_hugr(hugr); + + Ok(()) } -} -/// Returns an iterator over all non local edges in a Hugr beneath the entrypoint. -/// -/// All `(node, in_port)` pairs are returned where `in_port` is a value port connected to a -/// node whose parent is both beneath the entrypoint and different from the parent of `node`. -pub fn nonlocal_edges(hugr: &H) -> impl Iterator + '_ { - hugr.entry_descendants().flat_map(move |node| { - hugr.in_value_types(node).filter_map(move |(in_p, _)| { - let (src, _) = hugr.single_linked_output(node, in_p)?; - (hugr.get_parent(node) != hugr.get_parent(src) - && ancestors(src, hugr).any(|a| a == hugr.entrypoint())) - .then_some((node, in_p)) - }) - }) + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = scope.into(); + self + } } /// An error from [ensure_no_nonlocal_edges] @@ -57,69 +95,63 @@ pub enum FindNonLocalEdgesError { Edges(Vec<(N, IncomingPort)>), } -/// Verifies that there are no non local value edges in the Hugr. +/// Verifies that there are no non local value edges in the Hugr beneath the entrypoint. +#[deprecated(note = "Use LocalizeEdges::check_no_nonlocal_edges", since = "0.26.0")] pub fn ensure_no_nonlocal_edges( hugr: &H, ) -> Result<(), FindNonLocalEdgesError> { - let non_local_edges: Vec<_> = nonlocal_edges(hugr).collect_vec(); - if non_local_edges.is_empty() { - Ok(()) - } else { - Err(FindNonLocalEdgesError::Edges(non_local_edges))? - } -} - -fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Iterator { - v.into_iter().map(|(_, t)| t.clone()) + LocalizeEdges::new_for_hugr(hugr).check_no_nonlocal_edges(hugr) } -/// Converts all non-local edges in a Hugr into local ones, by inserting extra inputs -/// to container nodes and extra outports to Input nodes (and conversely to outputs of -/// [DataflowBlock]s). -/// -/// [DataflowBlock]: hugr_core::ops::DataflowBlock -pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { - // Group all the non-local edges in the graph by target node, - // storing for each the source and type (well-defined as these are Value edges). - let nonlocal_edges: Vec<_> = nonlocal_edges(hugr) - .map(|(node, inport)| { - // unwrap because nonlocal_edges(hugr) already skips in-ports with !=1 linked outputs. - let (src_n, outp) = hugr.single_linked_output(node, inport).unwrap(); - debug_assert!(hugr.get_parent(src_n).unwrap() != hugr.get_parent(node).unwrap()); - let Some(EdgeKind::Value(ty)) = hugr.get_optype(src_n).port_kind(outp) else { - panic!("impossible") - }; - (node, (Wire::new(src_n, outp), ty)) - }) - .collect(); - - if nonlocal_edges.is_empty() { - return Ok(()); +impl LocalizeEdges { + /// Create a new instance that works beneath the entrypoint only. + fn new_for_hugr(h: &impl HugrView) -> Self { + let scope = if h.entrypoint() == h.module_root() { + Preserve::Entrypoint.into() + } else { + PassScope::EntrypointRecursive + }; + Self { scope } } - // We now compute the sources needed by each parent node. - let needs_sources_map = { - let mut bnsm = ExtraSourceReqs::default(); - for (target_node, (source, ty)) in nonlocal_edges.iter() { - let parent = hugr.get_parent(*target_node).unwrap(); - debug_assert!(hugr.get_parent(parent).is_some()); - bnsm.add_edge(&*hugr, parent, *source, ty.clone()); + /// Verifies that there are no non local value edges in the Hugr beneath the + /// [PassScope::root] (ignoring [PassScope::recursive]). + pub fn check_no_nonlocal_edges( + &self, + hugr: &H, + ) -> Result<(), FindNonLocalEdgesError> { + let non_local_edges: Vec<_> = self.nonlocal_edges(hugr).collect(); + if non_local_edges.is_empty() { + Ok(()) + } else { + Err(FindNonLocalEdgesError::Edges(non_local_edges))? } - bnsm - }; + } - debug_assert!(nonlocal_edges.iter().all(|(n, (source, _))| { - let source_parent = hugr.get_parent(source.node()).unwrap(); - let source_gp = hugr.get_parent(source_parent); - ancestors(*n, hugr) - .skip(1) - .take_while(|&a| a != source_parent && source_gp.is_none_or(|gp| a != gp)) - .all(|parent| needs_sources_map.parent_needs(parent, *source)) - })); + fn nonlocal_edges<'b, H: HugrView>( + &self, + hugr: &'b H, + ) -> impl Iterator + use<'b, H> { + self.scope.root(hugr).into_iter().flat_map(move |root| { + hugr.descendants(root).flat_map(move |node| { + hugr.in_value_types(node).filter_map(move |(in_p, _)| { + let (src, _) = hugr.single_linked_output(node, in_p)?; + (hugr.get_parent(node) != hugr.get_parent(src) + && ancestors(src, hugr).any(|a| a == root)) + .then_some((node, in_p)) + }) + }) + }) + } +} - needs_sources_map.thread_hugr(hugr); +fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Iterator { + v.into_iter().map(|(_, t)| t.clone()) +} - Ok(()) +/// Runs a [LocalizeEdges] pass on the entrypoint subtree of a hugr. +pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { + LocalizeEdges::new_for_hugr(hugr).run(hugr) } fn ancestors(n: H::Node, h: &H) -> impl Iterator { @@ -128,6 +160,9 @@ fn ancestors(n: H::Node, h: &H) -> impl Iterator { #[cfg(test)] mod test { + use itertools::Itertools as _; + use rstest::rstest; + use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, extension::prelude::{Noop, bool_t, either_type}, @@ -136,7 +171,8 @@ mod test { type_row, types::Signature, }; - use rstest::rstest; + + use crate::composable::WithScope; use super::*; @@ -151,7 +187,9 @@ mod test { .outputs_arr(); builder.finish_hugr_with_outputs([out_w]).unwrap() }; - ensure_no_nonlocal_edges(&hugr).unwrap(); + LocalizeEdges::default() + .check_no_nonlocal_edges(&hugr) + .unwrap(); } #[test] @@ -178,7 +216,10 @@ mod test { (builder.finish_hugr_with_outputs([out_w]).unwrap(), edge) }; assert_eq!( - ensure_no_nonlocal_edges(&hugr).unwrap_err(), + LocalizeEdges::default() + .with_scope(PassScope::EntrypointRecursive) + .check_no_nonlocal_edges(&hugr) + .unwrap_err(), FindNonLocalEdgesError::Edges(vec![edge]) ); } @@ -203,10 +244,11 @@ mod test { }; outer.finish_hugr_with_outputs(inner_outs).unwrap() }; - assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + let pass = LocalizeEdges::default().with_scope(PassScope::EntrypointFlat); + assert!(pass.check_no_nonlocal_edges(&hugr).is_err()); remove_nonlocal_edges(&mut hugr).unwrap(); hugr.validate().unwrap(); - assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + assert!(pass.check_no_nonlocal_edges(&hugr).is_ok()); } #[test] @@ -246,10 +288,11 @@ mod test { }; outer.finish_hugr_with_outputs([s1, s2, s3]).unwrap() }; - assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + let pass = LocalizeEdges::default(); + assert!(pass.check_no_nonlocal_edges(&hugr).is_err()); remove_nonlocal_edges(&mut hugr).unwrap(); hugr.validate().unwrap_or_else(|e| panic!("{e}")); - assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + assert!(pass.check_no_nonlocal_edges(&hugr).is_ok()); } #[test] @@ -298,10 +341,11 @@ mod test { }; outer.finish_hugr_with_outputs([out]).unwrap() }; - assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + let pass = LocalizeEdges::default(); + assert!(pass.check_no_nonlocal_edges(&hugr).is_err()); remove_nonlocal_edges(&mut hugr).unwrap(); hugr.validate().unwrap_or_else(|e| panic!("{e}")); - assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + assert!(pass.check_no_nonlocal_edges(&hugr).is_ok()); } #[test] @@ -411,7 +455,8 @@ mod test { let [unit, out] = cfg.finish_sub_container().unwrap().outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([unit, out]).unwrap(); - let Err(FindNonLocalEdgesError::Edges(es)) = ensure_no_nonlocal_edges(&hugr) else { + let pass = LocalizeEdges::default(); + let Err(FindNonLocalEdgesError::Edges(es)) = pass.check_no_nonlocal_edges(&hugr) else { panic!() }; assert_eq!( @@ -423,7 +468,7 @@ mod test { ); remove_nonlocal_edges(&mut hugr).unwrap(); hugr.validate().unwrap(); - assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + assert!(pass.check_no_nonlocal_edges(&hugr).is_ok()); let dfb = |bb: BasicBlockID| hugr.get_optype(bb.node()).as_dataflow_block().unwrap(); // Entry node gets ext_edge_type added, only assert_eq!( diff --git a/hugr-passes/src/redundant_order_edges.rs b/hugr-passes/src/redundant_order_edges.rs index a62d0e28e..dba6c5ed3 100644 --- a/hugr-passes/src/redundant_order_edges.rs +++ b/hugr-passes/src/redundant_order_edges.rs @@ -10,7 +10,7 @@ use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort}; use itertools::Itertools; use petgraph::visit::Walker; -use crate::ComposablePass; +use crate::{ComposablePass, PassScope}; /// A pass for removing order edges in a Hugr region that are already implied by /// other order or dataflow dependencies. @@ -21,10 +21,10 @@ use crate::ComposablePass; /// Each evaluation on a region runs in `O(e + n log(n) * #order_edges)` time, /// where `e` and `n` are the number of edges and nodes in the region, /// respectively. -#[derive(Default, Debug, Clone, Copy)] +#[derive(Debug, Default, Clone)] pub struct RedundantOrderEdgesPass { - /// Whether to traverse the HUGR recursively. - recursive: bool, + /// On what part of the Hugr to run + scope: PassScope, } /// Result type for the redundant order edges pass. @@ -35,17 +35,6 @@ pub struct RedundantOrderEdgesResult { } impl RedundantOrderEdgesPass { - /// Create a new redundant order edges pass with the given configuration. - pub fn new() -> Self { - Self { recursive: true } - } - - /// Sets whether the pass should traverse the HUGR recursively. - pub fn recursive(mut self, recursive: bool) -> Self { - self.recursive = recursive; - self - } - /// Evaluate the pass on the given dataflow region. /// /// # Arguments @@ -79,23 +68,23 @@ impl RedundantOrderEdgesPass { // Traverse the region in topological order. let (region, node_map) = hugr.region_portgraph(parent); let postorder = petgraph::visit::Topo::new(®ion); - for pg_node in postorder.iter(®ion) { - let node = node_map.from_portgraph(pg_node); - let op = hugr.get_optype(node); + for pg_child in postorder.iter(®ion) { + let child = node_map.from_portgraph(pg_child); + let op = hugr.get_optype(child); - // If the node has children and we are running recursively, add the children to the region candidates. - if self.recursive && hugr.first_child(node).is_some() { - region_candidates.extend(hugr.children(node)); + // If the child itself is a region (parent) and we are running recursively, add the child to the region candidates. + if self.scope.recursive() && hugr.first_child(child).is_some() { + region_candidates.push_back(child); } - let predecessor_edges = predecessor_order_edges.remove(&node).unwrap_or_default(); + let predecessor_edges = predecessor_order_edges.remove(&child).unwrap_or_default(); // If we have reached the target of an order edge by exploring // connected nodes from the source, then mark the order edge for // removal. let removable_edges: HashSet> = predecessor_edges .iter() - .filter(|edge| edge.to_node == node) + .filter(|edge| edge.to_node == child) .copied() .collect(); @@ -110,13 +99,13 @@ impl RedundantOrderEdgesPass { // The latter may be necessary for keeping external edges valid. let new_edges = match op.other_output_port() { Some(out_order_port) => hugr - .linked_inputs(node, out_order_port) + .linked_inputs(child, out_order_port) .filter(|(to_node, _)| { hugr.get_parent(*to_node) == Some(parent) && hugr.first_child(*to_node).is_none() }) .map(|(to_node, to_port)| PredecessorOrderEdges { - from_node: node, + from_node: child, from_port: out_order_port, to_node, to_port, @@ -127,7 +116,7 @@ impl RedundantOrderEdgesPass { // Add the order edges to the `predecessor_order_edges` of the forward neighbors of the node. for out_port in op.value_output_ports().chain(op.static_output_port()) { - for (to_node, _) in hugr.linked_inputs(node, out_port) { + for (to_node, _) in hugr.linked_inputs(child, out_port) { if hugr.get_parent(to_node) != Some(parent) { continue; } @@ -139,7 +128,7 @@ impl RedundantOrderEdgesPass { } // Do not propagate new order edges through themselves (otherwise we'd always remove them). if let Some(out_port) = op.other_output_port() { - for (to_node, _) in hugr.linked_inputs(node, out_port) { + for (to_node, _) in hugr.linked_inputs(child, out_port) { if hugr.get_parent(to_node) != Some(parent) { continue; } @@ -167,9 +156,14 @@ impl> ComposablePass for RedundantOrderEdgesPass { type Error = HugrError; type Result = RedundantOrderEdgesResult; + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = scope.into(); + self + } + fn run(&self, hugr: &mut H) -> Result { // Nodes to explore in the hugr. - let mut region_candidates = VecDeque::from_iter([hugr.entrypoint()]); + let mut region_candidates = VecDeque::from_iter(self.scope.root(hugr)); let mut result = RedundantOrderEdgesResult::default(); while let Some(region) = region_candidates.pop_front() { @@ -177,8 +171,7 @@ impl> ComposablePass for RedundantOrderEdgesPass { if OpTag::DataflowParent.is_superset(op.tag()) { result += self.run_on_df_region(hugr, region, &mut region_candidates)?; - } else { - // When exploring non-dataflow regions, add the children recursively (independently of self.recursive). + } else if self.scope.recursive() { region_candidates.extend(hugr.children(region)); } } @@ -250,7 +243,7 @@ mod tests { let mut hugr = hugr.finish_hugr_with_outputs([noop5.out_wire(0)]).unwrap(); // Run the pass - let result = RedundantOrderEdgesPass::new().run(&mut hugr).unwrap(); + let result = RedundantOrderEdgesPass::default().run(&mut hugr).unwrap(); assert_eq!(result.edges_removed, 2); // Check that we removed the correct order edges.