Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 36 additions & 76 deletions hugr-core/src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
//! hierarchy, i.e. within a sibling graph. Convex subgraph are always
//! induced subgraphs, i.e. they are defined by a subset of the sibling nodes.

use std::cell::OnceCell;
use std::collections::HashSet;
use std::mem;

Expand Down Expand Up @@ -228,7 +227,7 @@ impl<N: HugrNode> SiblingSubgraph<N> {

/// Create a new convex sibling subgraph from input and output boundaries.
///
/// Provide a [`ConvexChecker`] instance to avoid constructing one for
/// Provide a [`TopoConvexChecker`] instance to avoid constructing one for
/// faster convexity check. If you do not have one, use
/// [`SiblingSubgraph::try_new`].
///
Expand Down Expand Up @@ -296,7 +295,7 @@ impl<N: HugrNode> SiblingSubgraph<N> {

/// Create a subgraph from a set of nodes.
///
/// Provide a [`ConvexChecker`] instance to avoid constructing one for
/// Provide a [`TopoConvexChecker`] instance to avoid constructing one for
/// faster convexity check. If you do not have one, use
/// [`SiblingSubgraph::try_from_nodes`].
///
Expand Down Expand Up @@ -357,7 +356,7 @@ impl<N: HugrNode> SiblingSubgraph<N> {
intervals: &LineIntervals,
line_checker: &LineConvexChecker<impl HugrView<Node = N>>,
) -> Result<Self, InvalidSubgraph<N>> {
if !line_checker.get_checker().is_convex_by_intervals(intervals) {
if !line_checker.checker.is_convex_by_intervals(intervals) {
return Err(InvalidSubgraph::NotConvex);
}

Expand Down Expand Up @@ -494,14 +493,14 @@ impl<N: HugrNode> SiblingSubgraph<N> {
let checker_ref = match mode {
ValidationMode::WithChecker(c) => Some(c),
ValidationMode::CheckConvexity => {
checker = ConvexChecker::new(hugr, self.get_parent(hugr));
checker = TopoConvexChecker::new(hugr, self.get_parent(hugr));
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ValidationMode::WithChecker specifically contains a TopoConvexChecker so this must match that

Some(&checker)
}
ValidationMode::SkipConvexity => None,
};
if let Some(checker) = checker_ref {
let (subpg, _) = make_pg_subgraph(hugr, &self.inputs, &self.outputs);
if !subpg.is_convex_with_checker(&checker.init_checker().0) {
if !subpg.is_convex_with_checker(&checker.checker) {
return Err(InvalidSubgraph::NotConvex);
}
}
Expand Down Expand Up @@ -983,33 +982,33 @@ pub type LineConvexChecker<'g, Base> =
/// This type is generic over the convexity checker used. If checking convexity
/// for circuit-like graphs, use [`LineConvexChecker`], otherwise use
/// [`TopoConvexChecker`].
#[derive(Clone)]
pub struct ConvexChecker<'g, Base: HugrView, Checker> {
/// The base HUGR to check convexity on.
base: &'g Base,
/// The parent of the region where we are checking convexity.
#[allow(unused)] // Useful for debugging
region_parent: Base::Node,
/// A lazily initialized convexity checker, along with a map from nodes in
/// the region to `Base` nodes.
checker: OnceCell<(Checker, Base::RegionPortgraphNodes)>,
/// A convexity checker initialized for the nodes in that region
checker: Checker,
/// a map from nodes in the region to `Base` nodes.
node_map: Base::RegionPortgraphNodes,
}

impl<'g, Base: HugrView, Checker: Clone> Clone for ConvexChecker<'g, Base, Checker> {
fn clone(&self) -> Self {
Self {
base: self.base,
region_parent: self.region_parent,
checker: self.checker.clone(),
}
}
}

impl<'g, Base: HugrView, Checker> ConvexChecker<'g, Base, Checker> {
impl<'g, Base, Checker> ConvexChecker<'g, Base, Checker>
where
Base: HugrView,
Checker: CreateConvexChecker<CheckerRegion<'g, Base>>,
{
/// Create a new convexity checker.
pub fn new(base: &'g Base, region_parent: Base::Node) -> Self {
let (region, node_map) = base.region_portgraph(region_parent);
let checker = Checker::new_convex_checker(region);
Self {
base,
region_parent,
checker: OnceCell::new(),
checker,
node_map,
}
}

Expand All @@ -1026,32 +1025,6 @@ impl<'g, Base: HugrView, Checker> ConvexChecker<'g, Base, Checker> {
}
}

impl<'g, Base, Checker> ConvexChecker<'g, Base, Checker>
where
Base: HugrView,
Checker: CreateConvexChecker<CheckerRegion<'g, Base>>,
{
/// Returns the portgraph convexity checker, initializing it if necessary.
fn init_checker(&self) -> &(Checker, Base::RegionPortgraphNodes) {
self.checker.get_or_init(|| {
let (region, node_map) = self.base.region_portgraph(self.region_parent);
let checker = Checker::new_convex_checker(region);
(checker, node_map)
})
}

/// Returns the node map from the region to the base HUGR.
#[expect(dead_code)]
fn get_node_map(&self) -> &Base::RegionPortgraphNodes {
&self.init_checker().1
}

/// Returns the portgraph convexity checker, initializing it if necessary.
fn get_checker(&self) -> &Checker {
&self.init_checker().0
}
}

impl<'g, Base, Checker> portgraph::algorithms::ConvexChecker for ConvexChecker<'g, Base, Checker>
where
Base: HugrView,
Expand All @@ -1069,7 +1042,7 @@ where
if nodes.peek().is_none() || nodes.peek().is_none() {
return true;
}
self.get_checker().is_convex(nodes, inputs, outputs)
self.checker.is_convex(nodes, inputs, outputs)
}
}

Expand All @@ -1079,12 +1052,11 @@ impl<'g, Base: HugrView> LineConvexChecker<'g, Base> {
&self,
nodes: impl IntoIterator<Item = Base::Node>,
) -> Option<LineIntervals> {
let (checker, node_map) = self.init_checker();
let nodes = nodes
.into_iter()
.map(|n| node_map.to_portgraph(n))
.map(|n| self.node_map.to_portgraph(n))
.collect_vec();
checker.get_intervals_from_nodes(nodes)
self.checker.get_intervals_from_nodes(nodes)
}

/// Return the line intervals defined by the given boundary ports in the
Expand All @@ -1096,39 +1068,37 @@ impl<'g, Base: HugrView> LineConvexChecker<'g, Base> {
&self,
ports: impl IntoIterator<Item = (Base::Node, Port)>,
) -> Option<LineIntervals> {
let (checker, node_map) = self.init_checker();
let ports = ports
.into_iter()
.map(|(n, p)| {
let node = node_map.to_portgraph(n);
checker
let node = self.node_map.to_portgraph(n);
self.checker
.graph()
.port_index(node, p.pg_offset())
.expect("valid port")
})
.collect_vec();
checker.get_intervals_from_boundary_ports(ports)
self.checker.get_intervals_from_boundary_ports(ports)
}

/// Return the nodes that are within the given line intervals.
pub fn nodes_in_intervals<'a>(
&'a self,
intervals: &'a LineIntervals,
) -> impl Iterator<Item = Base::Node> + 'a {
let (checker, node_map) = self.init_checker();
checker
self.checker
.nodes_in_intervals(intervals)
.map(|pg_node| node_map.from_portgraph(pg_node))
.map(|pg_node| self.node_map.from_portgraph(pg_node))
}

/// Get the lines passing through the given port.
pub fn lines_at_port(&self, node: Base::Node, port: impl Into<Port>) -> &[LineIndex] {
let (checker, node_map) = self.init_checker();
let port = checker
let port = self
.checker
.graph()
.port_index(node_map.to_portgraph(node), port.into().pg_offset())
.port_index(self.node_map.to_portgraph(node), port.into().pg_offset())
.expect("valid port");
checker.lines_at_port(port)
self.checker.lines_at_port(port)
}

/// Extend the given intervals to include the given node.
Expand All @@ -1139,16 +1109,14 @@ impl<'g, Base: HugrView> LineConvexChecker<'g, Base> {
///
/// If `false` is returned, the `intervals` are left unchanged.
pub fn try_extend_intervals(&self, intervals: &mut LineIntervals, node: Base::Node) -> bool {
let (checker, node_map) = self.init_checker();
let node = node_map.to_portgraph(node);
checker.try_extend_intervals(intervals, node)
let node = self.node_map.to_portgraph(node);
self.checker.try_extend_intervals(intervals, node)
}

/// Get the position of a node on its lines.
pub fn get_position(&self, node: Base::Node) -> Position {
let (checker, node_map) = self.init_checker();
let node = node_map.to_portgraph(node);
checker.get_position(node)
let node = self.node_map.to_portgraph(node);
self.checker.get_position(node)
}
}

Expand Down Expand Up @@ -1379,14 +1347,6 @@ fn is_non_value_edge<H: HugrView>(hugr: &H, node: H::Node, port: Port) -> bool {
is_other || is_static
}

impl<'a, 'c, G: HugrView, Checker: Clone> From<&'a ConvexChecker<'c, G, Checker>>
for std::borrow::Cow<'a, ConvexChecker<'c, G, Checker>>
{
fn from(value: &'a ConvexChecker<'c, G, Checker>) -> Self {
Self::Borrowed(value)
}
}

/// Errors that can occur while constructing a [`SimpleReplacement`].
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
Expand Down
Loading