Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions hugr-passes/src/composable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,18 +315,6 @@ where
}
}

// Note remove when deprecated constant_fold_pass / remove_dead_funcs are removed
pub(crate) fn validate_if_test<P: ComposablePass<H>, H: HugrMut>(
pass: P,
hugr: &mut H,
) -> Result<P::Result, ValidatePassError<H::Node, P::Error>> {
if cfg!(test) {
ValidatingPass::new(pass).run(hugr)
} else {
Ok(pass.run(hugr)?)
}
}

#[cfg(test)]
pub(crate) mod test {
use hugr_core::ops::Value;
Expand Down
41 changes: 6 additions & 35 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use hugr_core::{
};
use value_handle::ValueHandle;

use crate::composable::{ComposablePass, PassScope, WithScope, validate_if_test};
use crate::composable::{ComposablePass, PassScope, WithScope};
use crate::dataflow::{
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
partial_from_const,
Expand All @@ -26,7 +26,7 @@ use crate::dead_code::{DeadCodeElimError, DeadCodeElimPass, PreserveNode};
/// A configuration for the Constant Folding pass.
pub struct ConstantFoldPass {
allow_increase_termination: bool,
scope: Option<PassScope>,
scope: PassScope,
/// Each outer key Node must be either:
/// - a `FuncDefn` child of the root, if the root is a module; or
/// - the entrypoint, if the entrypoint is not a Module
Expand Down Expand Up @@ -101,11 +101,7 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
/// [`ConstFoldError::InvalidEntryPoint`] if an entry-point added by [`Self::with_inputs`]
/// was of an invalid [`OpType`]
fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
let Some(root) = self
.scope
.as_ref()
.map_or(Some(hugr.entrypoint()), |sc| sc.root(hugr))
else {
let Some(root) = self.scope.root(hugr) else {
return Ok(()); // Scope says do nothing
};
let fresh_node = Node::from(portgraph::NodeIndex::new(
Expand All @@ -130,7 +126,7 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
.map_err(|op| ConstFoldError::InvalidEntryPoint { node: n, op })?;
}

for node in self.scope.iter().flat_map(|sc| sc.preserve_interface(hugr)) {
for node in self.scope.preserve_interface(hugr) {
if node == hugr.module_root() || self.inputs.contains_key(&node) {
// Cannot prepopulate inputs for module-root; do not `join` with inputs explicitly specified.
continue;
Expand Down Expand Up @@ -182,12 +178,7 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
hugr.connect(lcst, OutgoingPort::from(0), n, inport);
}
// Eliminate dead code not required for the same entry points.
let dce = self
.scope
.as_ref()
.map_or(DeadCodeElimPass::<H>::default(), |scope| {
DeadCodeElimPass::<H>::default_with_scope(scope.clone())
});
let dce = DeadCodeElimPass::<H>::default_with_scope(self.scope.clone());
dce.with_entry_points(self.inputs.keys().copied())
.set_preserve_callback(if self.allow_increase_termination {
Arc::new(|_, _| PreserveNode::CanRemoveIgnoringChildren)
Expand All @@ -212,31 +203,11 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {

impl WithScope for ConstantFoldPass {
fn with_scope(mut self, scope: impl Into<PassScope>) -> Self {
self.scope = Some(scope.into());
self.scope = scope.into();
self
}
}

/// Exhaustively apply constant folding to a HUGR.
/// If the Hugr's entrypoint is its [`Module`], assumes all [`FuncDefn`] children are reachable.
///
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
/// [`Module`]: hugr_core::ops::OpType::Module
#[deprecated(note = "Use ConstantFoldPass with a PassScope", since = "0.25.7")]
pub fn constant_fold_pass<H: HugrMut<Node = Node> + 'static>(mut h: impl AsMut<H>) {
let h = h.as_mut();
let c = ConstantFoldPass::default();
let c = if h.get_optype(h.entrypoint()).is_module() {
let no_inputs: [(IncomingPort, _); 0] = [];
h.children(h.entrypoint())
.filter(|n| h.get_optype(*n).is_func_defn())
.fold(c, |c, n| c.with_inputs(n, no_inputs.iter().cloned()))
} else {
c
};
validate_if_test(c, h).unwrap();
}

struct ConstFoldContext;

impl ConstLoader<ValueHandle<Node>> for ConstFoldContext {
Expand Down
149 changes: 13 additions & 136 deletions hugr-passes/src/dead_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ use hugr_core::{
module_graph::{ModuleGraph, StaticNode},
ops::{OpTag, OpTrait},
};
use itertools::Either;
use petgraph::visit::{Dfs, Walker};

use crate::composable::{
ComposablePass, PassScope, Preserve, ValidatePassError, WithScope, validate_if_test,
};
use crate::composable::{Preserve, WithScope};
use crate::{ComposablePass, PassScope};

#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
Expand Down Expand Up @@ -48,76 +46,28 @@ fn reachable_funcs<'a, H: HugrView>(
})
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
/// A configuration for the Dead Function Removal pass.
pub struct RemoveDeadFuncsPass {
entry_points: Either<Vec<Node>, PassScope>,
}

impl Default for RemoveDeadFuncsPass {
fn default() -> Self {
Self {
entry_points: Either::Left(Vec::new()),
}
}
entry_points: PassScope,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we rename this to scope?

}

impl RemoveDeadFuncsPass {
#[deprecated(note = "Use RemoveDeadFuncsPass::with_scope", since = "0.25.7")]
/// Adds new entry points - these must be [`FuncDefn`] nodes
/// that are children of the [`Module`] at the root of the Hugr.
///
/// Overrides any [PassScope] set by a call to [Self::with_scope].
///
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
/// [`Module`]: hugr_core::ops::OpType::Module
pub fn with_module_entry_points(
mut self,
entry_points: impl IntoIterator<Item = Node>,
) -> Self {
let v = match self.entry_points {
Either::Left(ref mut v) => v,
Either::Right(_) => {
self.entry_points = Either::Left(Vec::new());
self.entry_points.as_mut().unwrap_left()
}
};
v.extend(entry_points);
self
}
}

impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
impl<H: HugrMut> ComposablePass<H> for RemoveDeadFuncsPass {
type Error = RemoveDeadFuncsError;
type Result = ();

fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
let mut entry_points = Vec::new();
match &self.entry_points {
Either::Left(ep) => {
for &n in ep {
if !hugr.get_optype(n).is_func_defn() {
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
}
debug_assert_eq!(hugr.get_parent(n), Some(hugr.module_root()));
entry_points.push(n);
}
if hugr.entrypoint() != hugr.module_root() {
entry_points.push(hugr.entrypoint())
}
}
Either::Right(
// If the entrypoint is the module root, not allowed to touch anything.
// Otherwise, we must keep the entrypoint (and can touch only inside it).
PassScope::EntrypointFlat | PassScope::EntrypointRecursive
// Optimize whole Hugr but keep all functions
| PassScope::Global(Preserve::All)) => {
return Ok(());
}
Either::Right(PassScope::Global(Preserve::Entrypoint)) if hugr.entrypoint() != hugr.module_root() => {
// If the entrypoint is the module root, not allowed to touch anything.
// Otherwise, we must keep the entrypoint (and can touch only inside it).
PassScope::EntrypointFlat | PassScope::EntrypointRecursive
// Optimize whole Hugr but keep all functions
| PassScope::Global(Preserve::All) => return Ok(()),
PassScope::Global(Preserve::Entrypoint) if hugr.entrypoint() != hugr.module_root() => {
entry_points.push(hugr.entrypoint());
}
Either::Right(PassScope::Global(_)) => {
PassScope::Global(_) => {
for n in hugr.children(hugr.module_root()) {
if hugr.get_optype(n).as_func_defn().is_some_and(|fd| fd.visibility() == &Visibility::Public)
{
Expand Down Expand Up @@ -156,39 +106,13 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {

impl WithScope for RemoveDeadFuncsPass {
fn with_scope(mut self, scope: impl Into<PassScope>) -> Self {
self.entry_points = Either::Right(scope.into());
self.entry_points = scope.into();
self
}
}

/// Deletes from the Hugr any functions that are not used by either `Call` or
/// `LoadFunction` nodes in reachable parts.
///
/// `entry_points` may provide a list of entry points, which must be [`FuncDefn`]s (children of the root).
/// The [HugrView::entrypoint] will also be used unless it is the [HugrView::module_root].
/// Note that for a [`Module`]-rooted Hugr with no `entry_points` provided, this will remove
/// all functions from the module.
///
/// # Errors
/// * If any node in `entry_points` is not a [`FuncDefn`]
///
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
/// [`Module`]: hugr_core::ops::OpType::Module
#[deprecated(note = "Use RemoveDeadFuncsPass with a PassScope", since = "0.25.7")]
#[expect(deprecated)]
pub fn remove_dead_funcs(
h: &mut impl HugrMut<Node = Node>,
entry_points: impl IntoIterator<Item = Node>,
) -> Result<(), ValidatePassError<Node, RemoveDeadFuncsError>> {
validate_if_test(
RemoveDeadFuncsPass::default().with_module_entry_points(entry_points),
h,
)
}

#[cfg(test)]
mod test {
use std::collections::HashMap;

use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use hugr_core::hugr::hugrmut::HugrMut;
Expand Down Expand Up @@ -240,53 +164,6 @@ mod test {
h
}

#[rstest]
#[case(false, [], vec![])] // No entry_points removes everything!
#[case(true, [], vec!["from_main", "main"])]
#[case(false, ["main"], vec!["from_main", "main"])]
#[case(false, ["from_main"], vec!["from_main"])]
#[case(false, ["pubfunc"], vec!["from_pub", "pubfunc"])]
#[case(true, ["from_pub"], vec!["from_main", "from_pub", "main"])]
#[case(false, ["from_pub", "pubfunc"], vec!["from_pub", "pubfunc"])]
fn remove_dead_funcs_entry_points(
#[case] use_hugr_entrypoint: bool,
#[case] entry_points: impl IntoIterator<Item = &'static str>,
#[case] retained_funcs: Vec<&'static str>,
) -> Result<(), Box<dyn std::error::Error>> {
let mut hugr = hugr(use_hugr_entrypoint);

let avail_funcs = hugr
.children(hugr.module_root())
.filter_map(|n| {
hugr.get_optype(n)
.as_func_defn()
.map(|fd| (fd.func_name().clone(), n))
})
.collect::<HashMap<_, _>>();

#[expect(deprecated)]
super::remove_dead_funcs(
&mut hugr,
entry_points
.into_iter()
.map(|name| *avail_funcs.get(name).unwrap())
.collect::<Vec<_>>(),
)
.unwrap();

let remaining_funcs = hugr
.nodes()
.filter_map(|n| {
hugr.get_optype(n)
.as_func_defn()
.map(|fd| fd.func_name().as_str())
})
.sorted()
.collect_vec();
assert_eq!(remaining_funcs, retained_funcs);
Ok(())
}

#[rstest]
#[case(Preserve::All, false, vec!["from_main", "from_pub", "main", "pubfunc"])]
#[case(PassScope::EntrypointFlat, true, vec!["from_main", "from_pub", "main", "pubfunc"])]
Expand Down
6 changes: 0 additions & 6 deletions hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,10 @@ pub use composable::{ComposablePass, InScope, PassScope};

// Pass re-exports
pub use dead_code::DeadCodeElimPass;
#[deprecated(note = "Use RemoveDeadFuncsPass instead", since = "0.25.7")]
#[expect(deprecated)] // Remove together
pub use dead_funcs::remove_dead_funcs;
pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass};
pub use force_order::{force_order, force_order_by_key};
pub use inline_funcs::inline_acyclic;
pub use lower::{lower_ops, replace_many_ops};
#[deprecated(note = "Use MonomorphizePass instead", since = "0.25.7")]
#[expect(deprecated)] // Remove together
pub use monomorphize::monomorphize;
pub use monomorphize::{MonomorphizePass, mangle_name};
#[deprecated(
note = "Use LocalizeEdgesPass::check_no_nonlocal_edges",
Expand Down
25 changes: 1 addition & 24 deletions hugr-passes/src/monomorphize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,9 @@ use hugr_core::{
use hugr_core::hugr::{HugrView, OpType, hugrmut::HugrMut};
use itertools::Itertools as _;

use crate::composable::{ValidatePassError, WithScope, validate_if_test};
use crate::composable::WithScope;
use crate::{ComposablePass, PassScope};

/// Replaces calls to polymorphic functions with calls to new monomorphic
/// instantiations of the polymorphic ones.
///
/// If the Hugr is [Module](OpType::Module)-rooted,
/// * then the original polymorphic [`FuncDefn`]s are left untouched (including Calls inside them)
/// - [`crate::remove_dead_funcs`] can be used when no other Hugr will be linked in that might instantiate these
/// * else, the originals are removed (they are invisible from outside the Hugr); however, note
/// that this behaviour is expected to change in a future release to match Module-rooted Hugrs.
///
/// If the Hugr is [`FuncDefn`](OpType::FuncDefn)-rooted with polymorphic
/// signature then the HUGR will not be modified.
///
/// Monomorphic copies of polymorphic functions will be added to the HUGR as
/// children of the root node. We make best effort to ensure that names (derived
/// from parent function names and concrete type args) of new functions are unique
/// whenever the names of their parents are unique, but this is not guaranteed.
#[deprecated(note = "Use MonomorphizePass instead", since = "0.25.7")]
pub fn monomorphize(
hugr: &mut impl HugrMut<Node = Node>,
) -> Result<(), ValidatePassError<Node, Infallible>> {
validate_if_test(MonomorphizePass::default(), hugr)
}

fn is_polymorphic(fd: &FuncDefn) -> bool {
!fd.signature().params().is_empty()
}
Expand Down
Loading
Loading