diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index dc2f796f7..5404cdba3 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -307,6 +307,7 @@ impl< } } +// Note remove when deprecated constant_fold_pass / remove_dead_funcs are removed pub(crate) fn validate_if_test, H: HugrMut>( pass: P, hugr: &mut H, @@ -319,7 +320,7 @@ pub(crate) fn validate_if_test, H: HugrMut>( } #[cfg(test)] -mod test { +pub(crate) mod test { use hugr_core::ops::Value; use itertools::{Either, Itertools}; @@ -333,12 +334,20 @@ mod test { use hugr_core::types::{Signature, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, NodeIndex}; + use crate::composable::WithScope; use crate::const_fold::{ConstFoldError, ConstantFoldPass}; use crate::dead_code::DeadCodeElimError; - use crate::untuple::{UntupleRecursive, UntupleResult}; - use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass}; + use crate::untuple::UntupleResult; + use crate::{DeadCodeElimPass, PassScope, ReplaceTypes, UntuplePass}; - use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass, validate_if_test}; + use super::{ComposablePass, IfThen, ValidatePassError, ValidatingPass}; + + pub(crate) fn run_validating, H: HugrMut>( + pass: P, + hugr: &mut H, + ) -> Result> { + ValidatingPass::new(pass).run(hugr) + } #[test] fn test_then() { @@ -444,7 +453,7 @@ mod test { fb.finish_hugr_with_outputs(untup.outputs()).unwrap() }; - let untup = UntuplePass::new(UntupleRecursive::Recursive); + let untup = UntuplePass::default().with_scope(PassScope::EntrypointRecursive); { // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple let mut repl = ReplaceTypes::default(); @@ -453,7 +462,7 @@ mod test { let ifthen = IfThen::, _, _, _>::new(repl, untup.clone()); let mut h = h.clone(); - let r = validate_if_test(ifthen, &mut h).unwrap(); + let r = run_validating(ifthen, &mut h).unwrap(); assert_eq!( r, Some(UntupleResult { @@ -470,7 +479,7 @@ mod test { repl.set_replace_type(i32_custom_t, INT_TYPES[6].clone()); let ifthen = IfThen::, _, _, _>::new(repl, untup); let mut h = h; - let r = validate_if_test(ifthen, &mut h).unwrap(); + let r = run_validating(ifthen, &mut h).unwrap(); assert_eq!(r, None); assert_eq!(h.children(h.entrypoint()).count(), 4); let mktup = h diff --git a/hugr-passes/src/composable/scope.rs b/hugr-passes/src/composable/scope.rs index 15a72766d..149d0cbf0 100644 --- a/hugr-passes/src/composable/scope.rs +++ b/hugr-passes/src/composable/scope.rs @@ -70,6 +70,9 @@ pub enum PassScope { /// also name (if public) for linking; and whether the node is a valid dataflow child /// or is a [DataflowBlock], [ExitBlock] or [Module]). /// +/// For lowering passes (whose goal is to change the interface!), generally this has no +/// effect. +/// /// [DataflowBlock]: OpType::DataflowBlock /// [ExitBlock]: OpType::ExitBlock /// [Module]: OpType::Module diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 82caf3dd7..5dd9d24d4 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -20,15 +20,16 @@ use crate::dataflow::{ partial_from_const, }; use crate::dead_code::{DeadCodeElimError, DeadCodeElimPass, PreserveNode}; -use crate::{ComposablePass, composable::validate_if_test}; +use crate::{ComposablePass, PassScope, composable::validate_if_test}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { allow_increase_termination: bool, + scope: Option, /// Each outer key Node must be either: /// - a `FuncDefn` child of the root, if the root is a module; or - /// - the root, if the root is not a Module + /// - the entrypoint, if the entrypoint is not a Module inputs: HashMap>, } @@ -67,8 +68,8 @@ impl ConstantFoldPass { } /// Specifies a number of external inputs to an entry point of the Hugr. - /// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` child of the root; - /// or for non-Module-rooted Hugrs, `node` is the root of the Hugr. (This is not + /// In normal use, for Module-entrypoint Hugrs, `node` is a `FuncDefn` child of the module; + /// or for non-Module-entrypoint Hugrs, `node` is the entrypoint of the Hugr. (This is not /// enforced, but it must be a container and not a module itself.) /// /// Multiple calls for the same entry-point combine their values, with later @@ -100,6 +101,13 @@ impl + 'static> ComposablePass for ConstantFoldPass { /// [`ConstFoldError::InvalidEntryPoint`] if an entry-point added by [`Self::with_inputs`] /// was of an invalid [`OpType`] fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { + let Some(root) = self + .scope + .as_ref() + .map_or(Some(hugr.entrypoint()), |sc| sc.root(hugr)) + else { + return Ok(()); // Scope says do nothing + }; let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -122,15 +130,24 @@ impl + 'static> ComposablePass for ConstantFoldPass { .map_err(|op| ConstFoldError::InvalidEntryPoint { node: n, op })?; } + for node in self.scope.iter().flat_map(|sc| sc.preserve_interface(hugr)) { + if node == hugr.module_root() || self.inputs.contains_key(&node) { + // Cannot prepopulate inputs for module-root; do not `join` with inputs explicitly specified. + continue; + } + const NO_INPUTS: [(IncomingPort, PartialValue); 0] = []; + m.prepopulate_inputs(node, NO_INPUTS) + .map_err(|op| ConstFoldError::InvalidEntryPoint { node, op })?; + } + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.entrypoint()).map(|[i, _]| i); let wires_to_break = hugr - .entry_descendants() + .descendants(root) .flat_map(|n| hugr.node_inputs(n).map(move |ip| (n, ip))) .filter(|(n, ip)| { - *n != hugr.entrypoint() - && matches!(hugr.get_optype(*n).port_kind(*ip), Some(EdgeKind::Value(_))) + *n != root && matches!(hugr.get_optype(*n).port_kind(*ip), Some(EdgeKind::Value(_))) }) .filter_map(|(n, ip)| { let (src, outp) = hugr.single_linked_output(n, ip).unwrap(); @@ -165,8 +182,13 @@ impl + 'static> ComposablePass for ConstantFoldPass { hugr.connect(lcst, OutgoingPort::from(0), n, inport); } // Eliminate dead code not required for the same entry points. - DeadCodeElimPass::::default() - .with_entry_points(self.inputs.keys().copied()) + let dce = self + .scope + .as_ref() + .map_or(DeadCodeElimPass::::default(), |scope| { + DeadCodeElimPass::::default().with_scope_internal(scope.clone()) + }); + dce.with_entry_points(self.inputs.keys().copied()) .set_preserve_callback(if self.allow_increase_termination { Arc::new(|_, _| PreserveNode::CanRemoveIgnoringChildren) } else { @@ -186,6 +208,11 @@ impl + 'static> ComposablePass for ConstantFoldPass { })?; Ok(()) } + + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = Some(scope.into()); + self + } } /// Exhaustively apply constant folding to a HUGR. @@ -193,6 +220,7 @@ impl + 'static> ComposablePass for ConstantFoldPass { /// /// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn /// [`Module`]: hugr_core::ops::OpType::Module +#[deprecated(note = "Use ConstantFoldPass with a PassScope", since = "0.25.7")] pub fn constant_fold_pass + 'static>(mut h: impl AsMut) { let h = h.as_mut(); let c = ConstantFoldPass::default(); diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index e3a540575..44a954c6a 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -3,8 +3,8 @@ use std::{ sync::LazyLock, }; -use hugr_core::ops::Const; use hugr_core::ops::handle::NodeHandle; +use hugr_core::{Visibility, ops::Const}; use itertools::Itertools; use rstest::rstest; @@ -30,10 +30,20 @@ use hugr_core::std_extensions::logic::LogicOp; use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row}; -use crate::ComposablePass as _; -use crate::dataflow::{DFContext, PartialValue, partial_from_const}; +use crate::{ComposablePass as _, composable::ValidatingPass}; +use crate::{ + PassScope, + composable::WithScope, + dataflow::{DFContext, PartialValue, partial_from_const}, +}; + +use super::{ConstFoldContext, ConstantFoldPass, ValueHandle}; -use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, constant_fold_pass}; +fn constant_fold_pass(h: &mut (impl HugrMut + 'static)) { + // the default ConstantFoldPass has no scope, i.e. preserving legacy behavior + let c = ConstantFoldPass::default().with_scope(PassScope::default()); + ValidatingPass::new(c).run(h).unwrap(); +} #[rstest] #[case(ConstInt::new_u(4, 2).unwrap(), true)] @@ -1592,9 +1602,10 @@ fn test_module() -> Result<(), Box> { let c17 = mb.add_constant(Value::from(ConstInt::new_u(5, 17)?)); let ad1 = mb.add_alias_declare("unused", TypeBound::Linear)?; let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; - let mut main = mb.define_function( + let mut main = mb.define_function_vis( "main", Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]), + Visibility::Public, )?; let lc7 = main.load_const(&c7); let lc17 = main.load_const(&c17); diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index 1d57f69cb..b91c4e8e5 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -6,7 +6,7 @@ use std::collections::{HashMap, HashSet, VecDeque}; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; -use crate::ComposablePass; +use crate::{ComposablePass, PassScope}; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -14,6 +14,8 @@ pub struct DeadCodeElimPass { /// Nodes that are definitely needed - e.g. `FuncDefns`, but could be anything. /// Hugr Root is assumed to be an entry point even if not mentioned here. entry_points: Vec, + /// If None, use entrypoint-subtree (even if module root) + scope: Option, /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [`PreserveNode::default_for`]. preserve_callback: Arc>, @@ -23,6 +25,8 @@ impl Default for DeadCodeElimPass { fn default() -> Self { Self { entry_points: Default::default(), + // Preserve pre-PassScope behaviour of affecting entrypoint subtree only: + scope: None, preserve_callback: Arc::new(PreserveNode::default_for), } } @@ -36,11 +40,13 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a, N> { entry_points: &'a Vec, + scope: &'a Option, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, + scope: &self.scope, }, f, ) @@ -97,11 +103,11 @@ impl DeadCodeElimPass { self } - /// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code - /// used to evaluate these nodes. - /// [`HugrView::entrypoint`] is assumed to be an entry point; - /// for Module roots the client will want to mark some of the `FuncDefn` children - /// as entry points too. + /// Mark some nodes as starting points for analysis, i.e. so we cannot eliminate any code + /// used to evaluate these nodes. (E.g. nodes at which we may start executing the Hugr.) + /// + /// Other starting points are added according to the [PassScope]. + // TODO should we deprecate this? i.e. require use of PreserveCallback / Hugr edges? pub fn with_entry_points(mut self, entry_points: impl IntoIterator) -> Self { self.entry_points.extend(entry_points); self @@ -111,7 +117,11 @@ impl DeadCodeElimPass { let mut must_preserve = HashMap::new(); let mut needed = HashSet::new(); let mut q = VecDeque::from_iter(self.entry_points.iter().copied()); - q.push_front(h.entrypoint()); + + match &self.scope { + None => q.push_back(h.entrypoint()), + Some(scope) => q.extend(scope.preserve_interface(h)), + }; while let Some(n) = q.pop_front() { if !h.contains_node(n) { return Err(DeadCodeElimError::NodeNotFound(n)); @@ -119,6 +129,10 @@ impl DeadCodeElimPass { if !needed.insert(n) { continue; } + // Ensure no orphans, e.g. when preserving an entrypoint deep within a Hugr + // being globally optimized. We could remove more from parent, but would require transforming + // (e.g. removing individual Output ports) not just deleting, so don't. + q.extend(h.get_parent(n)); for (i, ch) in h.children(n).enumerate() { if self.must_preserve(h, &mut must_preserve, ch) || match h.get_optype(ch) { @@ -181,9 +195,16 @@ impl ComposablePass for DeadCodeElimPass { type Result = (); fn run(&self, hugr: &mut H) -> Result<(), Self::Error> { + let root = match &self.scope { + None => hugr.entrypoint(), + Some(scope) => match scope.root(hugr) { + Some(root) => root, + None => return Ok(()), + }, + }; let needed = self.find_needed_nodes(&*hugr)?; let remove = hugr - .entry_descendants() + .descendants(root) .filter(|n| !needed.contains(n)) .collect::>(); for n in remove { @@ -191,6 +212,11 @@ impl ComposablePass for DeadCodeElimPass { } Ok(()) } + + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = Some(scope.into()); + self + } } #[cfg(test)] mod test { diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index a77c19fe9..2862e06e9 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -3,16 +3,17 @@ use std::collections::HashSet; use hugr_core::{ - HugrView, Node, + HugrView, Node, Visibility, hugr::hugrmut::HugrMut, module_graph::{ModuleGraph, StaticNode}, ops::{OpTag, OpTrait}, }; +use itertools::Either; use petgraph::visit::{Dfs, Walker}; use crate::{ - ComposablePass, - composable::{ValidatePassError, validate_if_test}, + ComposablePass, PassScope, + composable::{Preserve, ValidatePassError, validate_if_test}, }; #[derive(Debug, thiserror::Error)] @@ -48,23 +49,41 @@ fn reachable_funcs<'a, H: HugrView>( }) } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { - entry_points: Vec, + entry_points: Either, PassScope>, +} + +impl Default for RemoveDeadFuncsPass { + fn default() -> Self { + Self { + entry_points: Either::Left(Vec::new()), + } + } } impl RemoveDeadFuncsPass { + #[deprecated(note = "Use RemoveDeadFuncsPass::with_scope", since = "0.25.7")] /// Adds new entry points - these must be [`FuncDefn`] nodes /// that are children of the [`Module`] at the root of the Hugr. /// + /// Overrides any [PassScope] set by a call to [Self::with_scope_internal]. + /// /// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn /// [`Module`]: hugr_core::ops::OpType::Module pub fn with_module_entry_points( mut self, entry_points: impl IntoIterator, ) -> Self { - self.entry_points.extend(entry_points); + let v = match self.entry_points { + Either::Left(ref mut v) => v, + Either::Right(_) => { + self.entry_points = Either::Left(Vec::new()); + self.entry_points.as_mut().unwrap_left() + } + }; + v.extend(entry_points); self } } @@ -72,17 +91,50 @@ impl RemoveDeadFuncsPass { impl> ComposablePass for RemoveDeadFuncsPass { type Error = RemoveDeadFuncsError; type Result = (); + + /// Overrides any entrypoints set by a call to [Self::with_module_entry_points]. + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.entry_points = Either::Right(scope.into()); + self + } + fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { let mut entry_points = Vec::new(); - for &n in self.entry_points.iter() { - if !hugr.get_optype(n).is_func_defn() { - return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n }); + match &self.entry_points { + Either::Left(ep) => { + for &n in ep { + if !hugr.get_optype(n).is_func_defn() { + return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n }); + } + debug_assert_eq!(hugr.get_parent(n), Some(hugr.module_root())); + entry_points.push(n); + } + if hugr.entrypoint() != hugr.module_root() { + entry_points.push(hugr.entrypoint()) + } + } + Either::Right( + // If the entrypoint is the module root, not allowed to touch anything. + // Otherwise, we must keep the entrypoint (and can touch only inside it). + PassScope::EntrypointFlat | PassScope::EntrypointRecursive + // Optimize whole Hugr but keep all functions + | PassScope::Global(Preserve::All)) => { + return Ok(()); + } + Either::Right(PassScope::Global(Preserve::Entrypoint)) if hugr.entrypoint() != hugr.module_root() => { + entry_points.push(hugr.entrypoint()); + } + Either::Right(PassScope::Global(_)) => { + for n in hugr.children(hugr.module_root()) { + if hugr.get_optype(n).as_func_defn().is_some_and(|fd| fd.visibility() == &Visibility::Public) + { + entry_points.push(n); + } + } + if hugr.entrypoint() != hugr.module_root() { + entry_points.push(hugr.entrypoint()); + } } - debug_assert_eq!(hugr.get_parent(n), Some(hugr.module_root())); - entry_points.push(n); - } - if hugr.entrypoint() != hugr.module_root() { - entry_points.push(hugr.entrypoint()) } let mut reachable = @@ -109,8 +161,8 @@ impl> ComposablePass for RemoveDeadFuncsPass { } } -/// Deletes from the Hugr any functions that are not used by either [`Call`] or -/// [`LoadFunction`] nodes in reachable parts. +/// Deletes from the Hugr any functions that are not used by either `Call` or +/// `LoadFunction` nodes in reachable parts. /// /// `entry_points` may provide a list of entry points, which must be [`FuncDefn`]s (children of the root). /// The [HugrView::entrypoint] will also be used unless it is the [HugrView::module_root]. @@ -120,10 +172,10 @@ impl> ComposablePass for RemoveDeadFuncsPass { /// # Errors /// * If any node in `entry_points` is not a [`FuncDefn`] /// -/// [`Call`]: hugr_core::ops::OpType::Call /// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn -/// [`LoadFunction`]: hugr_core::ops::OpType::LoadFunction /// [`Module`]: hugr_core::ops::OpType::Module +#[deprecated(note = "Use RemoveDeadFuncsPass with a PassScope", since = "0.25.7")] +#[expect(deprecated)] pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, @@ -138,52 +190,70 @@ pub fn remove_dead_funcs( mod test { use std::collections::HashMap; + use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::handle::NodeHandle; + use hugr_core::{Hugr, Visibility}; + use hugr_core::{HugrView, extension::prelude::usize_t, types::Signature}; use itertools::Itertools; use rstest::rstest; - use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; - use hugr_core::hugr::hugrmut::HugrMut; - use hugr_core::{HugrView, extension::prelude::usize_t, types::Signature}; + use super::RemoveDeadFuncsPass; + use crate::PassScope; + use crate::composable::{Preserve, WithScope, test::run_validating}; + + fn hugr(use_entrypoint: bool) -> Hugr { + let mut hb = ModuleBuilder::new(); + let o2 = hb + .define_function("from_pub", Signature::new_endo(usize_t())) + .unwrap(); + let o2inp = o2.input_wires(); + let o2 = o2.finish_with_outputs(o2inp).unwrap(); + let mut o1 = hb + .define_function_vis( + "pubfunc", + Signature::new_endo(usize_t()), + Visibility::Public, + ) + .unwrap(); + + let o1c = o1.call(o2.handle(), &[], o1.input_wires()).unwrap(); + o1.finish_with_outputs(o1c.outputs()).unwrap(); - use super::remove_dead_funcs; + let fm = hb + .define_function("from_main", Signature::new_endo(usize_t())) + .unwrap(); + let f_inp = fm.input_wires(); + let fm = fm.finish_with_outputs(f_inp).unwrap(); + let mut m = hb + .define_function("main", Signature::new_endo(usize_t())) + .unwrap(); + let m_in = m.input_wires(); + let mut dfb = m.dfg_builder(Signature::new_endo(usize_t()), m_in).unwrap(); + let c = dfb.call(fm.handle(), &[], dfb.input_wires()).unwrap(); + let dfg = dfb.finish_with_outputs(c.outputs()).unwrap(); + m.finish_with_outputs(dfg.outputs()).unwrap(); + let mut h = hb.finish_hugr().unwrap(); + if use_entrypoint { + h.set_entrypoint(dfg.node()); + } + h + } #[rstest] #[case(false, [], vec![])] // No entry_points removes everything! #[case(true, [], vec!["from_main", "main"])] #[case(false, ["main"], vec!["from_main", "main"])] #[case(false, ["from_main"], vec!["from_main"])] - #[case(false, ["other1"], vec!["other1", "other2"])] - #[case(true, ["other2"], vec!["from_main", "main", "other2"])] - #[case(false, ["other1", "other2"], vec!["other1", "other2"])] + #[case(false, ["pubfunc"], vec!["from_pub", "pubfunc"])] + #[case(true, ["from_pub"], vec!["from_main", "from_pub", "main"])] + #[case(false, ["from_pub", "pubfunc"], vec!["from_pub", "pubfunc"])] fn remove_dead_funcs_entry_points( #[case] use_hugr_entrypoint: bool, #[case] entry_points: impl IntoIterator, #[case] retained_funcs: Vec<&'static str>, ) -> Result<(), Box> { - let mut hb = ModuleBuilder::new(); - let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?; - let o2inp = o2.input_wires(); - let o2 = o2.finish_with_outputs(o2inp)?; - let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?; - - let o1c = o1.call(o2.handle(), &[], o1.input_wires())?; - o1.finish_with_outputs(o1c.outputs())?; - - let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?; - let f_inp = fm.input_wires(); - let fm = fm.finish_with_outputs(f_inp)?; - let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?; - let m_in = m.input_wires(); - let mut dfg = m.dfg_builder(Signature::new_endo(usize_t()), m_in)?; - let c = dfg.call(fm.handle(), &[], dfg.input_wires())?; - let dfg = dfg.finish_with_outputs(c.outputs()).unwrap(); - m.finish_with_outputs(dfg.outputs())?; - - let mut hugr = hb.finish_hugr()?; - if use_hugr_entrypoint { - hugr.set_entrypoint(dfg.node()); - } + let mut hugr = hugr(use_hugr_entrypoint); let avail_funcs = hugr .children(hugr.module_root()) @@ -194,7 +264,8 @@ mod test { }) .collect::>(); - remove_dead_funcs( + #[expect(deprecated)] + super::remove_dead_funcs( &mut hugr, entry_points .into_iter() @@ -215,4 +286,32 @@ mod test { assert_eq!(remaining_funcs, retained_funcs); Ok(()) } + + #[rstest] + #[case(Preserve::All, false, vec!["from_main", "from_pub", "main", "pubfunc"])] + #[case(PassScope::EntrypointFlat, true, vec!["from_main", "from_pub", "main", "pubfunc"])] + #[case(PassScope::EntrypointRecursive, false, vec!["from_main", "from_pub", "main", "pubfunc"])] + #[case(Preserve::Public, true, vec!["from_main", "from_pub", "main", "pubfunc"])] + #[case(Preserve::Public, false, vec!["from_pub", "pubfunc"])] + #[case(Preserve::Entrypoint, true, vec!["from_main", "main"])] + fn remove_dead_funcs_scope( + #[case] scope: impl Into, + #[case] use_entrypoint: bool, + #[case] retained_funcs: Vec<&'static str>, + ) { + let scope = scope.into(); + let mut hugr = hugr(use_entrypoint); + run_validating(RemoveDeadFuncsPass::default().with_scope(scope), &mut hugr).unwrap(); + + let remaining_funcs = hugr + .nodes() + .filter_map(|n| { + hugr.get_optype(n) + .as_func_defn() + .map(|fd| fd.func_name().as_str()) + }) + .sorted() + .collect_vec(); + assert_eq!(remaining_funcs, retained_funcs); + } } diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index d0841b86c..99a081a44 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -27,11 +27,17 @@ pub use composable::{ComposablePass, InScope, PassScope}; // Pass re-exports pub use dead_code::DeadCodeElimPass; -pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_funcs}; +#[deprecated(note = "Use RemoveDeadFuncsPass instead", since = "0.25.7")] +#[expect(deprecated)] // Remove together +pub use dead_funcs::remove_dead_funcs; +pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass}; pub use force_order::{force_order, force_order_by_key}; pub use inline_funcs::inline_acyclic; pub use lower::{lower_ops, replace_many_ops}; -pub use monomorphize::{MonomorphizePass, mangle_name, monomorphize}; +#[deprecated(note = "Use MonomorphizePass instead", since = "0.25.7")] +#[expect(deprecated)] // Remove together +pub use monomorphize::monomorphize; +pub use monomorphize::{MonomorphizePass, mangle_name}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; pub use replace_types::ReplaceTypes; pub use untuple::UntuplePass; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 5ce486dde..9382a1d07 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -32,6 +32,7 @@ use crate::composable::{ValidatePassError, validate_if_test}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. +#[deprecated(note = "Use MonomorphizePass instead", since = "0.25.7")] pub fn monomorphize( hugr: &mut impl HugrMut, ) -> Result<(), ValidatePassError> { @@ -281,13 +282,14 @@ mod test { HugrBuilder, ModuleBuilder, }; use hugr_core::extension::prelude::{ConstUsize, UnpackTuple, UnwrapBuilder, usize_t}; - use hugr_core::ops::handle::{FuncID, NodeHandle}; + use hugr_core::ops::handle::FuncID; use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound, TypeEnum}; - use hugr_core::{Hugr, HugrView, Node}; + use hugr_core::{Hugr, HugrView, Node, Visibility}; use rstest::rstest; - use crate::{monomorphize, remove_dead_funcs}; + use crate::composable::{Preserve, WithScope}; + use crate::{ComposablePass, MonomorphizePass, RemoveDeadFuncsPass}; use super::{is_polymorphic, mangle_name}; @@ -306,7 +308,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - monomorphize(&mut hugr2).unwrap(); + MonomorphizePass.run(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -349,9 +351,13 @@ mod test { let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?; fb.finish_with_outputs(trip.outputs())? }; - let mn = { + { let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))]; - let mut fb = mb.define_function("main", Signature::new(usize_t(), outs))?; + let mut fb = mb.define_function_vis( + "main", + Signature::new(usize_t(), outs), + Visibility::Public, + )?; let [elem] = fb.input_wires_arr(); let [res1] = fb .call(tr.handle(), &[usize_t().into()], [elem])? @@ -359,7 +365,7 @@ mod test { let pair = fb.call(db.handle(), &[usize_t().into()], [elem])?; let pty = pair_type(usize_t()).into(); let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr(); - fb.finish_with_outputs([res1, res2])? + fb.finish_with_outputs([res1, res2])?; }; let mut hugr = mb.finish_hugr()?; assert_eq!( @@ -368,7 +374,7 @@ mod test { .count(), 3 ); - monomorphize(&mut hugr)?; + MonomorphizePass.run(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -389,12 +395,15 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - monomorphize(&mut mono2)?; + MonomorphizePass.run(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent let mut nopoly = mono; - remove_dead_funcs(&mut nopoly, [mn.node()])?; + RemoveDeadFuncsPass::default() + .with_scope(Preserve::Public) + .run(&mut nopoly) + .unwrap(); let mut funcs = list_funcs(&nopoly); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); @@ -542,7 +551,7 @@ mod test { let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); hugr.set_entrypoint(hugr.module_root()); // We want to act on everything, not just `main` - monomorphize(&mut hugr).unwrap(); + MonomorphizePass.run(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -620,8 +629,11 @@ mod test { module_builder.finish_hugr().unwrap() }; - monomorphize(&mut hugr).unwrap(); - remove_dead_funcs(&mut hugr, []).unwrap(); + MonomorphizePass.run(&mut hugr).unwrap(); + RemoveDeadFuncsPass::default() + .with_scope(Preserve::Public) + .run(&mut hugr) + .unwrap(); let funcs = list_funcs(&hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); diff --git a/hugr-passes/src/normalize_cfgs.rs b/hugr-passes/src/normalize_cfgs.rs index 20d2494c0..1cfe1655d 100644 --- a/hugr-passes/src/normalize_cfgs.rs +++ b/hugr-passes/src/normalize_cfgs.rs @@ -9,7 +9,7 @@ use std::collections::HashMap; use hugr_core::extension::prelude::UnpackTuple; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::types::{EdgeKind, Signature, TypeRow}; -use itertools::Itertools; +use itertools::{Either, Itertools}; use hugr_core::hugr::patch::inline_dfg::InlineDFG; use hugr_core::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; @@ -18,7 +18,7 @@ use hugr_core::ops::{ }; use hugr_core::{Direction, Hugr, HugrView, Node, OutgoingPort, PortIndex}; -use crate::ComposablePass; +use crate::{ComposablePass, PassScope}; /// Merge any basic blocks that are direct children of the specified [`CFG`]-entrypoint /// Hugr. @@ -98,22 +98,34 @@ pub enum NormalizeCFGResult { /// A [ComposablePass] that normalizes CFGs (i.e. [normalize_cfg]) in a Hugr. #[derive(Clone, Debug)] pub struct NormalizeCFGPass { - cfgs: Vec, + scope: Either, PassScope>, } impl Default for NormalizeCFGPass { fn default() -> Self { - Self { cfgs: vec![] } + Self { + scope: Either::Left(vec![]), + } } } impl NormalizeCFGPass { /// Allows mutating the set of CFG nodes that will be normalized. /// + /// Note that calling this method (even if the returned mut-ref is not written to) will + /// override any previous call to [Self::with_scope_internal]. + /// /// If empty (the default), all (non-strict) descendants of the [HugrView::entrypoint] /// will be normalized. + #[deprecated(note = "Use with_scope", since = "0.25.7")] pub fn cfgs(&mut self) -> &mut Vec { - &mut self.cfgs + match &mut self.scope { + Either::Left(cfgs) => cfgs, + r => { + *r = Either::Left(Vec::new()); + r.as_mut().unwrap_left() + } + } } } @@ -124,17 +136,30 @@ impl ComposablePass for NormalizeCFGPass { type Result = HashMap>; fn run(&self, hugr: &mut H) -> Result { - let cfgs = if self.cfgs.is_empty() { - let mut v = hugr - .entry_descendants() - .filter(|n| hugr.get_optype(*n).is_cfg()) - .collect::>(); - // Process inner CFGs first, in case they are removed (if they are in a completely - // disconnected block when the Entry node has only the Exit as successor). - v.reverse(); - v - } else { - self.cfgs.clone() + let cfgs = match &self.scope { + Either::Left(cfgs) if !cfgs.is_empty() => cfgs.clone(), + _ => { + let ctrs = match &self.scope { + Either::Left(v) => { + assert!(v.is_empty()); + Either::Right(hugr.descendants(hugr.entrypoint())) + } + Either::Right(scope) => { + let r = scope.root(hugr); + if let Some(r) = r.filter(|_| scope.recursive()) { + Either::Right(hugr.descendants(r)) + } else { + Either::Left(r.into_iter()) + } + } + }; + let mut cfgs: Vec = + ctrs.filter(|n| hugr.get_optype(*n).is_cfg()).collect(); + // Process inner CFGs first, in case they are removed (if they are in a completely + // disconnected block when the Entry node has only the Exit as successor). + cfgs.reverse(); + cfgs + } }; let mut results = HashMap::new(); for cfg in cfgs { @@ -143,6 +168,12 @@ impl ComposablePass for NormalizeCFGPass { } Ok(results) } + + /// Overrides any previous call to [Self::cfgs] + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = Either::Right(scope.into()); + self + } } /// Normalize a CFG in a Hugr: diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e9dbdad7a..a4cf1a87c 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -9,6 +9,7 @@ use handlers::list_const; use hugr_core::hugr::linking::{HugrLinking, NameLinkingPolicy, OnMultiDefn}; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::std_extensions::collections::list::list_type_def; +use itertools::Either; use thiserror::Error; use hugr_core::builder::{ @@ -28,7 +29,7 @@ use hugr_core::types::{ }; use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Visibility, Wire}; -use crate::ComposablePass; +use crate::{ComposablePass, PassScope}; mod linearize; pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; @@ -334,8 +335,11 @@ impl ReplacementOptions { } } -/// A configuration of what types, ops, and constants should be replaced with what. -/// May be applied to a Hugr via [`Self::run`]. +/// A *lowering* [ComposablePass] that replaces types, ops and constants, i.e. changing +/// node signatures/interfaces. +/// +/// The struct configures what types, ops, and constants should be replaced with what, +/// and may be applied to a Hugr via [`Self::run`]. /// /// Parametrized types and ops will be reparameterized taking into account the /// replacements, but any ops taking/returning the replaced types *not* as a result of @@ -388,7 +392,7 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - regions: Option>, + scope: Either>, } impl Default for ReplaceTypes { @@ -455,7 +459,9 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - regions: None, + // Not really clear what "preserve" means for a pass that changes signatures, + // but default to running on whole hugr not just entrypoint. + scope: Either::Left(PassScope::default()), } } @@ -703,9 +709,9 @@ impl ReplaceTypes { /// Set the regions of the Hugr to which this pass should be applied. /// /// If not set, the pass is applied to the whole Hugr. - /// Each call to overwrites any previous calls to `set_regions`. + /// Each call overwrites any previous calls to `set_regions` and/or [Self::with_scope_internal]. pub fn set_regions(&mut self, regions: impl IntoIterator) { - self.regions = Some(regions.into_iter().collect()); + self.scope = Either::Right(regions.into_iter().collect()); } fn process_subtree_opts( @@ -908,14 +914,25 @@ impl> ComposablePass for ReplaceTypes { type Error = ReplaceTypesError; type Result = bool; + /// Sets the scope within which the pass will operate. Note that this pass ignores + /// * [PassScope::preserve_interface], as this is a lowering pass: its purpose is to + /// change node signatures. + /// * [PassScope::recursive], as non-recursion generally leads to invalid Hugrs. + /// + /// Hence, really only the [PassScope::root] affects the pass. + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = Either::Left(scope.into()); + self + } + fn run(&self, hugr: &mut H) -> Result { let temp: Vec; // keep alive - let regions = match self.regions { - Some(ref regs) => regs, - None => { - temp = vec![hugr.module_root()]; + let regions = match &self.scope { + Either::Left(scope) => { + temp = Vec::from_iter(scope.root(hugr)); &temp } + Either::Right(regs) => regs, }; let mut changed = false; for region_root in regions { diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index 725e878c4..eef4bc545 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -11,22 +11,30 @@ use hugr_core::hugr::views::sibling_subgraph::TopoConvexChecker; use hugr_core::ops::{OpTrait, OpType}; use hugr_core::types::Type; use hugr_core::{HugrView, Node, PortIndex, SimpleReplacement}; -use itertools::Itertools; +use itertools::{Either, Itertools}; -use crate::ComposablePass; +use crate::{ComposablePass, PassScope}; /// Configuration enum for the untuple rewrite pass. /// /// Indicates whether the pattern match should traverse the HUGR recursively. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[deprecated(note = "Use PassScope instead", since = "0.25.7")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum UntupleRecursive { /// Traverse the HUGR recursively, i.e. consider the entire subtree Recursive, /// Do not traverse the HUGR recursively, i.e. consider only the sibling subgraph - #[default] NonRecursive, } +#[expect(deprecated)] // Remove along with UntupleRecursive +#[expect(clippy::derivable_impls)] // derive(Default) generates deprecation warning +impl Default for UntupleRecursive { + fn default() -> Self { + UntupleRecursive::NonRecursive + } +} + /// A pass that removes unnecessary `MakeTuple` operations immediately followed /// by `UnpackTuple`s. /// @@ -42,12 +50,29 @@ pub enum UntupleRecursive { /// /// Ignores pack/unpack nodes with order edges. // TODO: Supporting those requires updating the `SiblingSubgraph` implementation. See . -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct UntuplePass { - /// Whether to traverse the HUGR recursively. - recursive: UntupleRecursive, - /// Parent node under which to operate; None indicates the Hugr root - parent: Option, + /// Either a [PassScope] controlling which parts of the Hugr to process; + /// or a flag for recursiveness, and the parent node under which to operate + /// (None indicating the Hugr root) + #[expect(deprecated)] // remove Right half and just use PassScope + scope: Either)>, +} + +impl Default for UntuplePass { + fn default() -> Self { + #[expect(deprecated)] // Move to PassScope::Default() when UntupleRecursive is removed + Self { + scope: Either::Right((UntupleRecursive::default(), Option::default())), + } + } +} + +#[expect(deprecated)] // Remove along with UntupleRecursive +impl From for bool { + fn from(value: UntupleRecursive) -> Self { + value == UntupleRecursive::Recursive + } } #[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)] @@ -66,59 +91,122 @@ pub struct UntupleResult { } impl UntuplePass { - /// Create a new untuple pass with the given configuration. + /// Create a new untuple pass with the given recursiveness and that + /// will run on the entrypoint region/subtree. #[must_use] + #[deprecated( + note = "Use default() instead, followed by with_scope()", + since = "0.25.7" + )] + #[expect(deprecated)] // Remove along with UntupleRecursive pub fn new(recursive: UntupleRecursive) -> Self { Self { - recursive, - parent: None, + scope: Either::Right((recursive, None)), } } /// Sets the parent node to optimize (overwrites any previous setting) + /// + /// If the pass was previously configured by [Self::with_scope_internal] then + /// implicitly `[Self::set_recursive]`'s with [PassScope::recursive] + #[deprecated(note = "Use with_scope instead", since = "0.25.7")] + #[expect(deprecated)] // Remove along with UntupleRecursive pub fn set_parent(mut self, parent: impl Into>) -> Self { - self.parent = parent.into(); + match &mut self.scope { + Either::Left(p) => { + let rec = if p.recursive() { + UntupleRecursive::Recursive + } else { + UntupleRecursive::NonRecursive + }; + self.scope = Either::Right((rec, parent.into())) + } + Either::Right((_, p)) => *p = parent.into(), + }; self } /// Sets whether the pass should traverse the HUGR recursively. + /// + /// If the pass was last configured via [Self::with_scope_internal], overrides that, + /// with `set_parent` of default `None`. #[must_use] + #[deprecated(note = "Use with_scope", since = "0.25.7")] + #[expect(deprecated)] // Remove along with UntupleRecursive pub fn recursive(mut self, recursive: UntupleRecursive) -> Self { - self.recursive = recursive; + let parent = self.scope.right().and_then(|(_, p)| p); + self.scope = Either::Right((recursive, parent)); self } /// Find tuple pack operations followed by tuple unpack operations + /// beneath a specified parent and according to this instance's recursiveness + /// ([Self::recursive] or [Self::with_scope_internal] + [PassScope::recursive]) /// and generate rewrites to remove them. /// /// The returned rewrites are guaranteed to be independent of each other. /// /// Returns an iterator over the rewrites. + #[deprecated(note = "Use all_rewrites", since = "0.25.7")] pub fn find_rewrites( &self, hugr: &H, parent: H::Node, ) -> Vec> { - let mut res = Vec::new(); - let mut children_queue = VecDeque::new(); - children_queue.push_back(parent); - - // Required to create SimpleReplacements. - let mut convex_checker: Option> = None; - - while let Some(parent) = children_queue.pop_front() { - for node in hugr.children(parent) { - let op = hugr.get_optype(node); - if let Some(rw) = make_rewrite(hugr, &mut convex_checker, node, op) { - res.push(rw); - } - if self.recursive == UntupleRecursive::Recursive && op.is_container() { - children_queue.push_back(node); - } + let recursive = match &self.scope { + Either::Left(scope) => scope.recursive(), + Either::Right((rec, _)) => (*rec).into(), + }; + find_rewrites(hugr, parent, recursive) + } + + /// Find tuple pack operations followed by tuple unpack operations + /// and generate rewrites to remove them. + /// + /// The returned rewrites are guaranteed to be independent of each other. + /// + /// Returns an iterator over the rewrites. + pub fn all_rewrites>( + &self, + hugr: &H, + ) -> Vec> { + let (recursive, parent) = match &self.scope { + Either::Left(scope) => { + let Some(root) = scope.root(hugr) else { + return vec![]; + }; + (scope.recursive(), root) + } + Either::Right((rec, parent)) => ((*rec).into(), parent.unwrap_or(hugr.entrypoint())), + }; + find_rewrites(hugr, parent, recursive) + } +} + +fn find_rewrites( + hugr: &H, + parent: H::Node, + recursive: bool, +) -> Vec> { + let mut res = Vec::new(); + let mut children_queue = VecDeque::new(); + children_queue.push_back(parent); + + // Required to create SimpleReplacements. + let mut convex_checker: Option> = None; + + while let Some(parent) = children_queue.pop_front() { + for node in hugr.children(parent) { + let op = hugr.get_optype(node); + if let Some(rw) = make_rewrite(hugr, &mut convex_checker, node, op) { + res.push(rw); + } + if recursive && op.is_container() { + children_queue.push_back(node); } } - res } + res } impl> ComposablePass for UntuplePass { @@ -126,7 +214,7 @@ impl> ComposablePass for UntuplePass { type Result = UntupleResult; fn run(&self, hugr: &mut H) -> Result { - let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.entrypoint())); + let rewrites = self.all_rewrites(hugr); let rewrites_applied = rewrites.len(); // The rewrites are independent, so we can always apply them all. for rewrite in rewrites { @@ -134,6 +222,12 @@ impl> ComposablePass for UntuplePass { } Ok(UntupleResult { rewrites_applied }) } + + /// Overrides any [Self::set_parent] or [Self::recursive] + fn with_scope_internal(mut self, scope: impl Into) -> Self { + self.scope = Either::Left(scope.into()); + self + } } /// Returns true if the given optype is a `MakeTuple` operation. @@ -278,10 +372,10 @@ fn remove_pack_unpack<'h, T: HugrView>( #[cfg(test)] mod test { use super::*; + use crate::composable::WithScope; + use hugr_core::Hugr; use hugr_core::builder::FunctionBuilder; use hugr_core::extension::prelude::{UnpackTuple, bool_t, qb_t}; - - use hugr_core::Hugr; use hugr_core::ops::handle::NodeHandle; use hugr_core::std_extensions::arithmetic::float_types::float64_type; use hugr_core::types::Signature; @@ -475,14 +569,16 @@ mod test { #[case] mut hugr: Hugr, #[case] expected_rewrites: usize, #[case] remaining_nodes: usize, + #[values(true, false)] use_scope: bool, ) { - let pass = UntuplePass::default().recursive(UntupleRecursive::NonRecursive); - let parent = hugr.entrypoint(); - let res = pass - .set_parent(parent) - .run(&mut hugr) - .unwrap_or_else(|e| panic!("{e}")); + let pass = if use_scope { + UntuplePass::default().with_scope(PassScope::EntrypointFlat) + } else { + #[expect(deprecated)] // Remove use_scope==false case along with UntupleRecursive + UntuplePass::default().set_parent(parent) + }; + let res = pass.run(&mut hugr).unwrap_or_else(|e| panic!("{e}")); assert_eq!(res.rewrites_applied, expected_rewrites); assert_eq!(hugr.children(parent).count(), remaining_nodes); }