From c328a181cc9880c6986eb51cf95fbdd6e78b681c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 14:32:43 +0000 Subject: [PATCH 01/23] wip --- hugr-core/src/hugr/hugrmut.rs | 45 +++++- hugr-core/src/hugr/internal.rs | 17 ++- hugr-core/src/types.rs | 20 ++- hugr-passes/src/non_local.rs | 241 ++++++++++++++++++++++++++++++++- 4 files changed, 316 insertions(+), 7 deletions(-) diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 4056f36e61..b76d1897fe 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -4,6 +4,7 @@ use core::panic; use std::collections::HashMap; use std::sync::Arc; +use itertools::Itertools as _; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap}; @@ -11,7 +12,7 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; -use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; +use crate::{Direction, Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; use super::NodeMetadataMap; @@ -278,6 +279,48 @@ pub trait HugrMut: HugrMutInternals { fn extensions_mut(&mut self) -> &mut ExtensionRegistry { &mut self.hugr_mut().extensions } + + /// TODO perhaps these should be on HugrMut? + fn insert_incoming_port(&mut self, node: Node, index: usize) -> IncomingPort { + let _ = self + .add_ports(node, Direction::Incoming, 1) + .exactly_one() + .unwrap(); + + for (to, from) in (index..self.num_inputs(node)) + .map_into::() + .rev() + .tuple_windows() + { + let linked_outputs = self.linked_outputs(node, from).collect_vec(); + self.disconnect(node, from); + for (linked_node, linked_port) in linked_outputs { + self.connect(linked_node, linked_port, node, to); + } + } + index.into() + } + + /// TODO perhaps these should be on HugrMut? + fn insert_outgoing_port(&mut self, node: Node, index: usize) -> OutgoingPort { + let _ = self + .add_ports(node, Direction::Outgoing, 1) + .exactly_one() + .unwrap(); + + for (to, from) in (index..self.num_outputs(node)) + .map_into::() + .rev() + .tuple_windows() + { + let linked_inputs = self.linked_inputs(node, from).collect_vec(); + self.disconnect(node, from); + for (linked_node, linked_port) in linked_inputs { + self.connect(node, to, linked_node, linked_port); + } + } + index.into() + } } /// Records the result of inserting a Hugr or view diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 3f1c6b6ff7..30f8727161 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -245,6 +245,13 @@ pub trait HugrMutInternals: RootTagged { } self.hugr_mut().replace_op(node, op) } + + /// TODO docs + fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { + panic_invalid_node(self, node); + // TODO refuse if node == self.root() because tag might be violated + self.hugr_mut().get_optype_mut(node) + } } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -305,7 +312,13 @@ impl + AsMut> HugrMutInternals for T { fn replace_op(&mut self, node: Node, op: impl Into) -> Result { // We know RootHandle=Node here so no need to check - let cur = self.hugr_mut().op_types.get_mut(node.pg_index()); - Ok(std::mem::replace(cur, op.into())) + Ok(std::mem::replace( + self.hugr_mut().get_optype_mut(node)?, + op.into(), + )) + } + + fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { + Ok(self.hugr_mut().op_types.get_mut(node.pg_index())) } } diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 962e00876c..025666eb0a 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -27,7 +27,7 @@ pub use type_row::{TypeRow, TypeRowRV}; pub(crate) use poly_func::PolyFuncTypeBase; use itertools::FoldWhile::{Continue, Done}; -use itertools::{repeat_n, Itertools}; +use itertools::{repeat_n, Either, Itertools}; #[cfg(test)] use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; @@ -256,6 +256,16 @@ impl SumType { _ => None, } } + + /// TODO docs + pub fn iter_variants(&self) -> impl Iterator { + match self { + SumType::Unit { size } => { + Either::Left(repeat_n(TypeRV::EMPTY_TYPEROW_REF, *size as usize)) + } + SumType::General { rows } => Either::Right(rows.iter()), + } + } } impl From for TypeBase { @@ -453,6 +463,14 @@ impl TypeBase { &mut self.0 } + /// TODO docs + pub fn as_sum_type(&self) -> Option<&SumType> { + match &self.0 { + TypeEnum::Sum(s) => Some(s), + _ => None, + } + } + /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index efb5e7139e..5b31de9b13 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,11 +1,18 @@ //! This module provides functions for inspecting and modifying the nature of //! non local edges in a Hugr. +use ascent::hashbrown::HashMap; // //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions -use itertools::Itertools as _; +use itertools::{Either, Itertools as _}; use thiserror::Error; -use hugr_core::{HugrView, IncomingPort, Node}; +use hugr_core::{ + builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, + hugr::hugrmut::HugrMut, + ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, + types::{EdgeKind, Type, TypeRow}, + HugrView, IncomingPort, Node, PortIndex, Wire, +}; /// Returns an iterator over all non local edges in a Hugr. /// @@ -38,12 +45,178 @@ pub fn ensure_no_nonlocal_edges(hugr: &impl HugrView) -> Result<(), NonLocalEdge } } +#[derive(Debug, Clone)] +struct WorkItem { + source: Wire, + target: (Node, IncomingPort), + ty: Type, +} + +fn thread_dataflow_parent( + hugr: &mut impl HugrMut, + parent: Node, + port_index: usize, + ty: Type, +) -> Wire { + let [i, _] = hugr.get_io(parent).unwrap(); + let OpType::Input(mut input) = hugr.get_optype(i).clone() else { + panic!("impossible") + }; + input.types.to_mut().insert(port_index, ty); + hugr.replace_op(i, input).unwrap(); + let input_port = hugr.insert_outgoing_port(i, port_index); + Wire::new(i, input_port) +} + +fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> (WorkItem, Wire) { + let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); + let new_port_index = tailloop.just_inputs.len(); + tailloop.just_inputs.to_mut().push(ty.clone()); + hugr.replace_op(node, tailloop).unwrap(); + let tailloop_port = hugr.insert_incoming_port(node, new_port_index); + hugr.connect(source.node(), source.source(), node, tailloop_port); + let workitem = WorkItem { + source, + target: (node, tailloop_port), + ty: ty.clone(), + }; + + let input_wire = thread_dataflow_parent(hugr, node, tailloop_port.index(), ty.clone()); + + let [_, o] = hugr.get_io(node).unwrap(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = + hugr.get_optype(o).port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum_type() else { + panic!("impossible") + }; + + let old_sum_rows: Vec = sum_type + .iter_variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows = { + let mut v = old_sum_rows.clone(); + v[TailLoop::CONTINUE_TAG].to_mut().push(ty.clone()); + v + }; + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = + ConditionalBuilder::new(old_sum_rows, ty.clone(), new_control_type.clone()).unwrap(); + for i in 0..2 { + let mut case = cond.case_builder(i).unwrap(); + let inputs = { + let all_inputs = case.input_wires(); + if i == TailLoop::CONTINUE_TAG { + Either::Left(all_inputs) + } else { + Either::Right(all_inputs.into_iter().dropping_back(1)) + } + }; + + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(node, cond).new_root; + let (n, p) = hugr.single_linked_output(o, 0).unwrap(); + hugr.connect(n, p, cond_node, 0); + hugr.connect(input_wire.node(), input_wire.source(), cond_node, 1); + hugr.disconnect(o, IncomingPort::from(0)); + hugr.connect(cond_node, 0, o, 0); + let mut output = hugr.get_optype(o).as_output().unwrap().clone(); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(o, output).unwrap(); + (workitem, input_wire) +} + +pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { + let mut non_local_edges = nonlocal_edges(hugr) + .map(|target @ (node, inport)| { + let source = { + let (n, p) = hugr.single_linked_output(node, inport).unwrap(); + Wire::new(n, p) + }; + debug_assert!( + hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() + ); + let Some(EdgeKind::Value(ty)) = hugr + .get_optype(hugr.get_parent(source.node()).unwrap()) + .port_kind(source.source()) + else { + panic!("impossible") + }; + WorkItem { source, target, ty } + }) + .collect_vec(); + + if non_local_edges.is_empty() { + return Ok(()); + } + + let mut parent_source_map = HashMap::new(); + + while let Some(WorkItem { source, target, ty }) = non_local_edges.pop() { + dbg!(&source, target, &ty); + let parent = hugr.get_parent(target.0).unwrap(); + let local_source = if hugr.get_parent(source.node()).unwrap() == parent { + &source + } else { + parent_source_map + .entry((parent, source)) + .or_insert_with(|| { + let (workitem, wire) = match hugr.get_optype(parent).clone() { + OpType::DFG(mut dfg) => { + let new_port_index = dfg.signature.input.len(); + dbg!(&dfg, new_port_index); + dfg.signature.input.to_mut().push(ty.clone()); + hugr.replace_op(parent, dfg).unwrap(); + let dfg_port = hugr.insert_incoming_port(parent, new_port_index); + hugr.connect(source.node(), source.source(), parent, dfg_port); + ( + WorkItem { + source, + target: (parent, dfg_port), + ty: ty.clone(), + }, + thread_dataflow_parent(hugr, parent, dfg_port.index(), ty), + ) + } + OpType::DataflowBlock(dataflow_block) => todo!(), + OpType::TailLoop(_) => do_tailloop(hugr, parent, source, ty), + OpType::Case(case) => todo!(), + _ => panic!("impossible"), + }; + non_local_edges.push(workitem); + wire + }) + }; + hugr.disconnect(target.0, target.1); + hugr.connect( + local_source.node(), + local_source.source(), + target.0, + target.1, + ); + } + + Ok(()) +} + #[cfg(test)] mod test { use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, extension::prelude::{bool_t, Noop}, - ops::handle::NodeHandle, + ops::{handle::NodeHandle, Tag, TailLoop}, type_row, types::Signature, }; @@ -94,4 +267,66 @@ mod test { NonLocalEdgesError::Edges(vec![edge]) ); } + + #[test] + fn remove_nonlocal_edges_dfg() { + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); + let [w0] = outer.input_wires_arr(); + let [w1] = { + let inner = outer + .dfg_builder(Signature::new(type_row![], bool_t()), []) + .unwrap(); + inner.finish_with_outputs([w0]).unwrap().outputs_arr() + }; + outer.finish_hugr_with_outputs([w1]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap(); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } + + #[test] + fn remove_nonlocal_edges_tailloop() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new_endo(vec![ + t1.clone(), + t2.clone(), + t3.clone(), + ])) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [s2, s3] = { + let mut inner = outer + .tail_loop_builder( + [(t1.clone(), s1)], + [(t3.clone(), s3)], + vec![t2.clone()].into(), + ) + .unwrap(); + let [_s1, s3] = inner.input_wires_arr(); + let control = inner + .add_dataflow_op( + Tag::new( + TailLoop::BREAK_TAG, + vec![vec![t1.clone()].into(), vec![t2.clone()].into()], + ), + [s2], + ) + .unwrap() + .out_wire(0); + inner + .finish_with_outputs(control, [s3]) + .unwrap() + .outputs_arr() + }; + outer.finish_hugr_with_outputs([s1, s2, s3]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } } From cd2348d576a7d45ff576677a3a7d262273f100d3 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 14:56:08 +0000 Subject: [PATCH 02/23] add pass --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/non_local.rs | 37 ++++++++++++++++++++++++++++++++--- hugr-passes/src/validation.rs | 2 +- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index c02d3d8591..39162a52e6 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -24,6 +24,7 @@ lazy_static = { workspace = true } paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } +derive_more = { workspace = true, features = ["from", "error", "display"] } [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 5b31de9b13..9e5af7cb85 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -4,7 +4,6 @@ use ascent::hashbrown::HashMap; // //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions use itertools::{Either, Itertools as _}; -use thiserror::Error; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, @@ -14,6 +13,34 @@ use hugr_core::{ HugrView, IncomingPort, Node, PortIndex, Wire, }; +use crate::validation::{ValidatePassError, ValidationLevel}; + +/// TODO docs +#[derive(Debug, Clone, Default)] +pub struct UnNonLocalPass { + validation: ValidationLevel, +} + +impl UnNonLocalPass { + /// Sets the validation level used before and after the pass is run. + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self + } + + /// Run the Monomorphization pass. + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { + remove_nonlocal_edges(hugr)?; + Ok(()) + } + + /// Run the pass using specified configuration. + pub fn run(&self, hugr: &mut H) -> Result<(), NonLocalEdgesError> { + self.validation + .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) + } +} + /// Returns an iterator over all non local edges in a Hugr. /// /// All `(node, in_port)` pairs are returned where `in_port` is a value port @@ -29,10 +56,14 @@ pub fn nonlocal_edges(hugr: &impl HugrView) -> impl Iterator), + #[from] + ValidationError(ValidatePassError), } /// Verifies that there are no non local value edges in the Hugr. diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index baf3b86d83..5f53f403c7 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -23,7 +23,7 @@ pub enum ValidationLevel { WithExtensions, } -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] From e04390ebefacec82b4ea1ad09fffe9285a44bad8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 15:01:07 +0000 Subject: [PATCH 03/23] oops --- hugr-passes/src/non_local.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 9e5af7cb85..3f588d9e2d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -180,7 +180,7 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() ); let Some(EdgeKind::Value(ty)) = hugr - .get_optype(hugr.get_parent(source.node()).unwrap()) + .get_optype(source.node()) .port_kind(source.source()) else { panic!("impossible") From 903acc2eb847b4254226adf140a8cf4b5ea80bd7 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 15:42:54 +0000 Subject: [PATCH 04/23] conditional --- hugr-passes/src/non_local.rs | 129 +++++++++++++++++++++++++++-------- 1 file changed, 99 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 3f588d9e2d..455b9e90fd 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -199,36 +199,62 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge dbg!(&source, target, &ty); let parent = hugr.get_parent(target.0).unwrap(); let local_source = if hugr.get_parent(source.node()).unwrap() == parent { - &source + source + } else if let Some(wire) = parent_source_map.get(&(parent,source)) { + *wire } else { - parent_source_map - .entry((parent, source)) - .or_insert_with(|| { - let (workitem, wire) = match hugr.get_optype(parent).clone() { - OpType::DFG(mut dfg) => { - let new_port_index = dfg.signature.input.len(); - dbg!(&dfg, new_port_index); - dfg.signature.input.to_mut().push(ty.clone()); - hugr.replace_op(parent, dfg).unwrap(); - let dfg_port = hugr.insert_incoming_port(parent, new_port_index); - hugr.connect(source.node(), source.source(), parent, dfg_port); - ( - WorkItem { - source, - target: (parent, dfg_port), - ty: ty.clone(), - }, - thread_dataflow_parent(hugr, parent, dfg_port.index(), ty), - ) + let (workitem, wire) = match hugr.get_optype(parent).clone() { + OpType::DFG(mut dfg) => { + let new_port_index = dfg.signature.input.len(); + dbg!(&dfg, new_port_index); + dfg.signature.input.to_mut().push(ty.clone()); + hugr.replace_op(parent, dfg).unwrap(); + let dfg_port = hugr.insert_incoming_port(parent, new_port_index); + hugr.connect(source.node(), source.source(), parent, dfg_port); + let wire = thread_dataflow_parent(hugr, parent, dfg_port.index(), ty.clone()); + let _ = parent_source_map.insert((parent, source), wire); + ( + WorkItem { + source, + target: (parent, dfg_port), + ty + }, + wire + ) + } + OpType::DataflowBlock(dataflow_block) => todo!(), + OpType::TailLoop(_) => { + let (workitem, wire) = do_tailloop(hugr, parent, source, ty); + let _ = parent_source_map.insert((parent, source), wire); + (workitem, wire) + } + OpType::Case(_) => { + let cond_node = hugr.get_parent(parent).unwrap(); + let mut cond = hugr.get_optype(cond_node).as_conditional().unwrap().clone(); + let new_port_index = cond.signature().input().len(); + cond.other_inputs.to_mut().push(ty.clone()); + hugr.replace_op(cond_node, cond).unwrap(); + let cond_port = hugr.insert_incoming_port(cond_node, new_port_index); + let mut this_wire = None; + for (case_n, mut case) in hugr.children(cond_node).filter_map(|n| { + let case = hugr.get_optype(n).as_case()?; + Some((n, case.clone())) + }).collect_vec() { + let case_port_index = case.signature.input().len(); + case.signature.input.to_mut().push(ty.clone()); + hugr.replace_op(case_n, case).unwrap(); + let case_input_wire = thread_dataflow_parent(hugr, case_n, case_port_index, ty.clone()); + let _ = parent_source_map.insert((case_n, source), case_input_wire); + if case_n == parent { + this_wire = Some(case_input_wire); } - OpType::DataflowBlock(dataflow_block) => todo!(), - OpType::TailLoop(_) => do_tailloop(hugr, parent, source, ty), - OpType::Case(case) => todo!(), - _ => panic!("impossible"), - }; - non_local_edges.push(workitem); - wire - }) + } + (WorkItem { source, target: (cond_node, cond_port), ty }, this_wire.unwrap()) + } + _ => panic!("impossible"), + }; + non_local_edges.push(workitem); + wire }; hugr.disconnect(target.0, target.1); hugr.connect( @@ -245,9 +271,9 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge #[cfg(test)] mod test { use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, + builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, extension::prelude::{bool_t, Noop}, - ops::{handle::NodeHandle, Tag, TailLoop}, + ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, types::Signature, }; @@ -360,4 +386,47 @@ mod test { hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } + + #[test] + fn remove_nonlocal_edges_cond() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + let out_variants = vec![t1.clone().into(), t2.clone().into()]; + let out_type = Type::new_sum(out_variants.clone()); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new(vec![ + t1.clone(), + t2.clone(), + t3.clone() + ], out_type.clone())) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [out] = { + let mut cond = outer + .conditional_builder((vec![type_row![];3], s3), [], out_type.into()).unwrap(); + + { + let mut case = cond.case_builder(0).unwrap(); + let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [s1]).unwrap().outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + { + let mut case = cond.case_builder(1).unwrap(); + let [r] = case.add_dataflow_op(Tag::new(1, out_variants.clone()), [s2]).unwrap().outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + { + let mut case = cond.case_builder(2).unwrap(); + let u = case.add_load_value(Value::unit()); + let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [u]).unwrap().outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + cond.finish_sub_container().unwrap().outputs_arr() + }; + outer.finish_hugr_with_outputs([out]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } } From a87d705f5156a13c0559075d02ab3308bbd0c19c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 6 Feb 2025 15:44:58 +0000 Subject: [PATCH 05/23] remove dbg --- hugr-passes/src/non_local.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 455b9e90fd..1bfa764374 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -196,7 +196,6 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge let mut parent_source_map = HashMap::new(); while let Some(WorkItem { source, target, ty }) = non_local_edges.pop() { - dbg!(&source, target, &ty); let parent = hugr.get_parent(target.0).unwrap(); let local_source = if hugr.get_parent(source.node()).unwrap() == parent { source @@ -206,7 +205,6 @@ pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdge let (workitem, wire) = match hugr.get_optype(parent).clone() { OpType::DFG(mut dfg) => { let new_port_index = dfg.signature.input.len(); - dbg!(&dfg, new_port_index); dfg.signature.input.to_mut().push(ty.clone()); hugr.replace_op(parent, dfg).unwrap(); let dfg_port = hugr.insert_incoming_port(parent, new_port_index); From 11879a8efdb863dedb357d5e0fbfaca19b5a7875 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Feb 2025 13:55:02 +0000 Subject: [PATCH 06/23] refactor --- hugr-passes/src/non_local.rs | 481 +++++++++++++++++++++++++---------- 1 file changed, 353 insertions(+), 128 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 1bfa764374..26584d5a34 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,16 +1,23 @@ //! This module provides functions for inspecting and modifying the nature of //! non local edges in a Hugr. -use ascent::hashbrown::HashMap; -// +use std::{ + collections::{BTreeMap, HashMap, HashSet, VecDeque}, + iter, +}; + //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions use itertools::{Either, Itertools as _}; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, - hugr::hugrmut::HugrMut, + hugr::{ + hugrmut::HugrMut, + views::{DescendantsGraph, HierarchyView}, + HugrError, + }, ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, - HugrView, IncomingPort, Node, PortIndex, Wire, + HugrView, IncomingPort, Node, Wire, }; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -21,7 +28,7 @@ pub struct UnNonLocalPass { validation: ValidationLevel, } -impl UnNonLocalPass { +impl UnNonLocalPass { /// Sets the validation level used before and after the pass is run. pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; @@ -30,7 +37,8 @@ impl UnNonLocalPass { /// Run the Monomorphization pass. fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { - remove_nonlocal_edges(hugr)?; + let root = hugr.root(); + remove_nonlocal_edges(hugr, root)?; Ok(()) } @@ -64,6 +72,8 @@ pub enum NonLocalEdgesError { Edges(Vec<(Node, IncomingPort)>), #[from] ValidationError(ValidatePassError), + #[from] + HugrError(HugrError), } /// Verifies that there are no non local value edges in the Hugr. @@ -86,33 +96,51 @@ struct WorkItem { fn thread_dataflow_parent( hugr: &mut impl HugrMut, parent: Node, - port_index: usize, - ty: Type, -) -> Wire { - let [i, _] = hugr.get_io(parent).unwrap(); - let OpType::Input(mut input) = hugr.get_optype(i).clone() else { + start_port_index: usize, + types: Vec, +) -> impl Iterator { + let [input_n, _] = hugr.get_io(parent).unwrap(); + let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { panic!("impossible") }; - input.types.to_mut().insert(port_index, ty); - hugr.replace_op(i, input).unwrap(); - let input_port = hugr.insert_outgoing_port(i, port_index); - Wire::new(i, input_port) + let mut r = vec![]; + for (i, ty) in types.into_iter().enumerate() { + input + .types + .to_mut() + .insert(start_port_index + i, ty.clone()); + r.push(Wire::new( + input_n, + hugr.insert_outgoing_port(input_n, start_port_index + i), + )); + } + hugr.replace_op(input_n, input).unwrap(); + r.into_iter() } -fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> (WorkItem, Wire) { +fn do_tailloop( + parent_source_map: &mut ParentSourceMap, + hugr: &mut impl HugrMut, + node: Node, + sources: impl IntoIterator, +) -> impl Iterator { + let (sources, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); - let new_port_index = tailloop.just_inputs.len(); - tailloop.just_inputs.to_mut().push(ty.clone()); - hugr.replace_op(node, tailloop).unwrap(); - let tailloop_port = hugr.insert_incoming_port(node, new_port_index); - hugr.connect(source.node(), source.source(), node, tailloop_port); - let workitem = WorkItem { - source, - target: (node, tailloop_port), - ty: ty.clone(), - }; + let start_port_index = tailloop.just_inputs.len(); + { + tailloop.just_inputs.to_mut().extend(types.iter().cloned()); + hugr.replace_op(node, tailloop).unwrap(); + } + let tailloop_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) + .collect_vec(); - let input_wire = thread_dataflow_parent(hugr, node, tailloop_port.index(), ty.clone()); + let input_wires = + thread_dataflow_parent(hugr, node, start_port_index, types.clone()).collect_vec(); + parent_source_map.insert( + node, + iter::zip(sources.iter().copied(), input_wires.iter().copied()), + ); let [_, o] = hugr.get_io(node).unwrap(); let (cond, new_control_type) = { @@ -131,13 +159,15 @@ fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> ( .collect_vec(); let new_sum_rows = { let mut v = old_sum_rows.clone(); - v[TailLoop::CONTINUE_TAG].to_mut().push(ty.clone()); + v[TailLoop::CONTINUE_TAG] + .to_mut() + .extend(types.iter().cloned()); v }; let new_control_type = Type::new_sum(new_sum_rows.clone()); let mut cond = - ConditionalBuilder::new(old_sum_rows, ty.clone(), new_control_type.clone()).unwrap(); + ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()).unwrap(); for i in 0..2 { let mut case = cond.case_builder(i).unwrap(); let inputs = { @@ -145,7 +175,7 @@ fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> ( if i == TailLoop::CONTINUE_TAG { Either::Left(all_inputs) } else { - Either::Right(all_inputs.into_iter().dropping_back(1)) + Either::Right(all_inputs.into_iter().dropping_back(types.len())) } }; @@ -160,107 +190,290 @@ fn do_tailloop(hugr: &mut impl HugrMut, node: Node, source: Wire, ty: Type) -> ( let cond_node = hugr.insert_hugr(node, cond).new_root; let (n, p) = hugr.single_linked_output(o, 0).unwrap(); hugr.connect(n, p, cond_node, 0); - hugr.connect(input_wire.node(), input_wire.source(), cond_node, 1); + for (i, w) in input_wires.into_iter().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } hugr.disconnect(o, IncomingPort::from(0)); hugr.connect(cond_node, 0, o, 0); let mut output = hugr.get_optype(o).as_output().unwrap().clone(); output.types.to_mut()[0] = new_control_type; hugr.replace_op(o, output).unwrap(); - (workitem, input_wire) + mk_workitems(node, sources, tailloop_ports, types) } -pub fn remove_nonlocal_edges(hugr: &mut impl HugrMut) -> Result<(), NonLocalEdgesError> { - let mut non_local_edges = nonlocal_edges(hugr) - .map(|target @ (node, inport)| { - let source = { - let (n, p) = hugr.single_linked_output(node, inport).unwrap(); - Wire::new(n, p) - }; - debug_assert!( - hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() +#[derive(Clone, Default, Debug)] +struct ParentSourceMap(HashMap>); + +impl ParentSourceMap { + fn contains_parent(&self, parent: Node) -> bool { + self.0.contains_key(&parent) + } + + fn insert(&mut self, parent: Node, sources: impl IntoIterator) { + debug_assert!(!self.0.contains_key(&parent)); + self.0.entry(parent).or_default().extend(sources); + } + + fn get(&self, parent: Node, source: Wire) -> Option { + self.0.get(&parent).and_then(|m| m.get(&source).cloned()) + } +} + +fn mk_workitems( + node: Node, + sources: impl IntoIterator, + ports: impl IntoIterator, + types: impl IntoIterator, +) -> impl Iterator { + itertools::izip!(sources, ports, types).map(move |(source, p, ty)| WorkItem { + source, + target: (node, p), + ty, + }) +} + +fn thread_sources( + parent_source_map: &mut ParentSourceMap, + hugr: &mut impl HugrMut, + bb: Node, + sources: impl IntoIterator, +) -> Vec { + let (source_wires, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); + match hugr.get_optype(bb).clone() { + OpType::DFG(mut dfg) => { + debug_assert!(!parent_source_map.contains_parent(bb)); + let start_new_port_index = dfg.signature.input().len(); + let new_dfg_ports = (0..source_wires.len()) + .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) + .collect_vec(); + dfg.signature.input.to_mut().extend(types.clone()); + hugr.replace_op(bb, dfg).unwrap(); + for (source, &target) in iter::zip(source_wires.iter(), new_dfg_ports.iter()) { + hugr.connect(source.node(), source.source(), bb, target); + } + parent_source_map.insert( + bb, + iter::zip( + source_wires.iter().copied(), + thread_dataflow_parent(hugr, bb, start_new_port_index, types.clone()), + ), ); - let Some(EdgeKind::Value(ty)) = hugr - .get_optype(source.node()) - .port_kind(source.source()) - else { - panic!("impossible") - }; - WorkItem { source, target, ty } - }) - .collect_vec(); + mk_workitems(bb, source_wires, new_dfg_ports, types).collect_vec() + } + OpType::Conditional(mut cond) => { + debug_assert!(!parent_source_map.contains_parent(bb)); + let start_new_port_index = cond.signature().input().len(); + cond.other_inputs.to_mut().extend(types.clone()); + hugr.replace_op(bb, cond).unwrap(); + let new_cond_ports = (0..source_wires.len()) + .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) + .collect_vec(); + parent_source_map.insert(bb, iter::empty()); + mk_workitems(bb, source_wires, new_cond_ports, types).collect_vec() + } + OpType::Case(mut case) => { + debug_assert!(!parent_source_map.contains_parent(bb)); + let start_case_port_index = case.signature.input().len(); + case.signature.input.to_mut().extend(types.clone()); + hugr.replace_op(bb, case).unwrap(); + parent_source_map.insert( + bb, + iter::zip( + source_wires.iter().copied(), + thread_dataflow_parent(hugr, bb, start_case_port_index, types), + ), + ); + vec![] + } + OpType::TailLoop(_) => { + do_tailloop(parent_source_map, hugr, bb, iter::zip(source_wires, types)).collect_vec() + } + _ => panic!("impossible"), + } + // _ => panic!("impossible"), + // }; + // non_local_edges.push(workitem); + // wire + // }; + // hugr.disconnect(target.0, target.1); + // hugr.connect( + // local_source.node(), + // local_source.source(), + // target.0, + // target.1, + // ); + // } +} - if non_local_edges.is_empty() { - return Ok(()); +#[derive(Debug, Default, Clone)] +struct BBNeedsSourcesMapBuilder(HashMap>); + +impl BBNeedsSourcesMapBuilder { + fn insert(&mut self, bb: Node, source: Wire, ty: Type) { + self.0.entry(bb).or_default().insert(source, ty); } - let mut parent_source_map = HashMap::new(); + fn extend_parent_needs_for(&mut self, ref hugr: impl HugrView, child: Node) -> bool { + let parent = hugr.get_parent(child).unwrap(); + let parent_needs = self + .0 + .get(&child) + .into_iter() + .flat_map(move |m| { + m.iter().filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) + .map(|(&w, ty)| (w, ty.clone())) + }) + .collect_vec(); + let any = !parent_needs.is_empty(); + if any { + self.0.entry(parent).or_default().extend(parent_needs); + } + any + } - while let Some(WorkItem { source, target, ty }) = non_local_edges.pop() { - let parent = hugr.get_parent(target.0).unwrap(); - let local_source = if hugr.get_parent(source.node()).unwrap() == parent { - source - } else if let Some(wire) = parent_source_map.get(&(parent,source)) { - *wire - } else { - let (workitem, wire) = match hugr.get_optype(parent).clone() { - OpType::DFG(mut dfg) => { - let new_port_index = dfg.signature.input.len(); - dfg.signature.input.to_mut().push(ty.clone()); - hugr.replace_op(parent, dfg).unwrap(); - let dfg_port = hugr.insert_incoming_port(parent, new_port_index); - hugr.connect(source.node(), source.source(), parent, dfg_port); - let wire = thread_dataflow_parent(hugr, parent, dfg_port.index(), ty.clone()); - let _ = parent_source_map.insert((parent, source), wire); - ( - WorkItem { - source, - target: (parent, dfg_port), - ty - }, - wire - ) - } - OpType::DataflowBlock(dataflow_block) => todo!(), - OpType::TailLoop(_) => { - let (workitem, wire) = do_tailloop(hugr, parent, source, ty); - let _ = parent_source_map.insert((parent, source), wire); - (workitem, wire) - } - OpType::Case(_) => { - let cond_node = hugr.get_parent(parent).unwrap(); - let mut cond = hugr.get_optype(cond_node).as_conditional().unwrap().clone(); - let new_port_index = cond.signature().input().len(); - cond.other_inputs.to_mut().push(ty.clone()); - hugr.replace_op(cond_node, cond).unwrap(); - let cond_port = hugr.insert_incoming_port(cond_node, new_port_index); - let mut this_wire = None; - for (case_n, mut case) in hugr.children(cond_node).filter_map(|n| { - let case = hugr.get_optype(n).as_case()?; - Some((n, case.clone())) - }).collect_vec() { - let case_port_index = case.signature.input().len(); - case.signature.input.to_mut().push(ty.clone()); - hugr.replace_op(case_n, case).unwrap(); - let case_input_wire = thread_dataflow_parent(hugr, case_n, case_port_index, ty.clone()); - let _ = parent_source_map.insert((case_n, source), case_input_wire); - if case_n == parent { - this_wire = Some(case_input_wire); - } - } - (WorkItem { source, target: (cond_node, cond_port), ty }, this_wire.unwrap()) + fn finish(mut self, hugr: impl HugrView) -> HashMap> { + let conds = self + .0 + .keys() + .copied() + .filter(|&n| hugr.get_optype(n).is_conditional()) + .collect_vec(); + for cond in conds { + if hugr.get_optype(cond).is_conditional() { + let cases = hugr + .children(cond) + .filter(|&child| hugr.get_optype(child).is_case()) + .collect_vec(); + let all_needed: BTreeMap<_, _> = cases + .iter() + .flat_map(|&case| { + let case_needed = self.0.get(&case); + case_needed + .into_iter() + .flat_map(|m| m.iter().map(|(&w, ty)| (w, ty.clone()))) + }) + .collect(); + for case in cases { + let _ = self.0.insert(case, all_needed.clone()); } - _ => panic!("impossible"), + } + } + self.0 + } +} + +pub fn remove_nonlocal_edges( + hugr: &mut impl HugrMut, + root: Node, +) -> Result<(), NonLocalEdgesError> { + let nonlocal_edges_map: HashMap = + nonlocal_edges(&DescendantsGraph::::try_new(hugr, root)?) + .map(|target @ (node, inport)| { + let source = { + let (n, p) = hugr.single_linked_output(node, inport).unwrap(); + Wire::new(n, p) + }; + debug_assert!( + hugr.get_parent(source.node()).unwrap() != hugr.get_parent(node).unwrap() + ); + let Some(EdgeKind::Value(ty)) = + hugr.get_optype(source.node()).port_kind(source.source()) + else { + panic!("impossible") + }; + (node, WorkItem { source, target, ty }) + }) + .collect(); + + if nonlocal_edges_map.is_empty() { + return Ok(()); + } + + let bb_needs_sources_map = { + let nonlocal_sorted = { + let mut v = iter::successors(Some(vec![root]), |nodes| { + let children = nodes + .iter() + .flat_map(|&n| hugr.children(n)) + .collect_vec(); + (!children.is_empty()).then_some(children) + }) + .flatten() + .filter_map(|n| nonlocal_edges_map.get(&n)) + .collect_vec(); + v.reverse(); + v + }; + let mut parent_set = HashSet::::new(); + // earlier items are deeper in the heirarchy + let mut parent_worklist = VecDeque::::new(); + let mut add_parent = |p, wl: &mut VecDeque<_>| { + if parent_set.insert(p) { + wl.push_back(p); + } + }; + let mut bnsm = BBNeedsSourcesMapBuilder::default(); + for workitem in nonlocal_sorted { + let parent = hugr.get_parent(workitem.target.0).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.insert(parent, workitem.source, workitem.ty.clone()); + add_parent(parent, &mut parent_worklist); + } + + while let Some(bb_node) = parent_worklist.pop_front() { + let Some(parent) = hugr.get_parent(bb_node) else { + continue; }; - non_local_edges.push(workitem); - wire + if bnsm.extend_parent_needs_for(&hugr, bb_node) { + add_parent(parent, &mut parent_worklist); + } + } + bnsm.finish(&hugr) + }; + + #[cfg(debug_assertions)] + { + for (&n, wi) in nonlocal_edges_map.iter() { + let mut m = n; + loop { + let parent = hugr.get_parent(m).unwrap(); + if hugr.get_parent(wi.source.node()).unwrap() == parent { + break; + } + assert!(bb_needs_sources_map[&parent].contains_key(&wi.source)); + m = parent; + } + } + + for &bb in bb_needs_sources_map.keys() { + assert!(hugr.get_parent(bb).is_some()); + } + } + + let mut worklist = nonlocal_edges_map.into_values().collect_vec(); + let mut parent_source_map = ParentSourceMap::default(); + + for (bb, needs_sources) in bb_needs_sources_map { + worklist.extend(thread_sources( + &mut parent_source_map, + hugr, + bb, + needs_sources, + )); + } + + let parent_source_map = parent_source_map; + + while let Some(wi) = worklist.pop() { + let parent = hugr.get_parent(wi.target.0).unwrap(); + let source = if hugr.get_parent(wi.source.node()).unwrap() == parent { + wi.source + } else { + parent_source_map.get(parent, wi.source).unwrap() }; - hugr.disconnect(target.0, target.1); - hugr.connect( - local_source.node(), - local_source.source(), - target.0, - target.1, - ); + debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); + hugr.disconnect(wi.target.0, wi.target.1); + hugr.connect(source.node(), source.source(), wi.target.0, wi.target.1); } Ok(()) @@ -337,7 +550,8 @@ mod test { outer.finish_hugr_with_outputs([w1]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); - remove_nonlocal_edges(&mut hugr).unwrap(); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); hugr.validate().unwrap(); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } @@ -380,7 +594,8 @@ mod test { outer.finish_hugr_with_outputs([s1, s2, s3]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); - remove_nonlocal_edges(&mut hugr).unwrap(); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } @@ -391,31 +606,40 @@ mod test { let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = Type::new_sum(out_variants.clone()); let mut hugr = { - let mut outer = DFGBuilder::new(Signature::new(vec![ - t1.clone(), - t2.clone(), - t3.clone() - ], out_type.clone())) + let mut outer = DFGBuilder::new(Signature::new( + vec![t1.clone(), t2.clone(), t3.clone()], + out_type.clone(), + )) .unwrap(); let [s1, s2, s3] = outer.input_wires_arr(); let [out] = { let mut cond = outer - .conditional_builder((vec![type_row![];3], s3), [], out_type.into()).unwrap(); + .conditional_builder((vec![type_row![]; 3], s3), [], out_type.into()) + .unwrap(); { let mut case = cond.case_builder(0).unwrap(); - let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [s1]).unwrap().outputs_arr(); + let [r] = case + .add_dataflow_op(Tag::new(0, out_variants.clone()), [s1]) + .unwrap() + .outputs_arr(); case.finish_with_outputs([r]).unwrap(); } { let mut case = cond.case_builder(1).unwrap(); - let [r] = case.add_dataflow_op(Tag::new(1, out_variants.clone()), [s2]).unwrap().outputs_arr(); + let [r] = case + .add_dataflow_op(Tag::new(1, out_variants.clone()), [s2]) + .unwrap() + .outputs_arr(); case.finish_with_outputs([r]).unwrap(); } { let mut case = cond.case_builder(2).unwrap(); let u = case.add_load_value(Value::unit()); - let [r] = case.add_dataflow_op(Tag::new(0, out_variants.clone()), [u]).unwrap().outputs_arr(); + let [r] = case + .add_dataflow_op(Tag::new(0, out_variants.clone()), [u]) + .unwrap() + .outputs_arr(); case.finish_with_outputs([r]).unwrap(); } cond.finish_sub_container().unwrap().outputs_arr() @@ -423,7 +647,8 @@ mod test { outer.finish_hugr_with_outputs([out]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); - remove_nonlocal_edges(&mut hugr).unwrap(); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } From 5f09ba96fbd0b199c5c368625a4605124296978a Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 7 Feb 2025 17:18:31 +0000 Subject: [PATCH 07/23] works --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/non_local.rs | 654 +++++++++++++++++++++++------------ 2 files changed, 441 insertions(+), 214 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 39162a52e6..fe7db4be5a 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -25,6 +25,7 @@ paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } derive_more = { workspace = true, features = ["from", "error", "display"] } +delegate.workspace = true [features] extension_inference = ["hugr-core/extension_inference"] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 26584d5a34..22bef220e8 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,8 +1,9 @@ //! This module provides functions for inspecting and modifying the nature of //! non local edges in a Hugr. +use delegate::delegate; use std::{ collections::{BTreeMap, HashMap, HashSet, VecDeque}, - iter, + iter, mem, }; //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions @@ -17,7 +18,7 @@ use hugr_core::{ }, ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, - HugrView, IncomingPort, Node, Wire, + HugrView, IncomingPort, Node, PortIndex, Wire, }; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -93,220 +94,375 @@ struct WorkItem { ty: Type, } -fn thread_dataflow_parent( - hugr: &mut impl HugrMut, - parent: Node, - start_port_index: usize, - types: Vec, -) -> impl Iterator { - let [input_n, _] = hugr.get_io(parent).unwrap(); - let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { - panic!("impossible") - }; - let mut r = vec![]; - for (i, ty) in types.into_iter().enumerate() { - input - .types - .to_mut() - .insert(start_port_index + i, ty.clone()); - r.push(Wire::new( - input_n, - hugr.insert_outgoing_port(input_n, start_port_index + i), - )); +#[derive(Clone, Default, Debug)] +struct ParentSourceMap(HashMap>); + +impl ParentSourceMap { + // fn contains_parent(&self, parent: Node) -> bool { + // self.0.contains_key(&parent) + // } + + fn insert_sources_in_parent( + &mut self, + parent: Node, + sources: impl IntoIterator, + ) { + debug_assert!(!self.0.contains_key(&parent)); + self.0.entry(parent).or_default().extend(sources); } - hugr.replace_op(input_n, input).unwrap(); - r.into_iter() -} -fn do_tailloop( - parent_source_map: &mut ParentSourceMap, - hugr: &mut impl HugrMut, - node: Node, - sources: impl IntoIterator, -) -> impl Iterator { - let (sources, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); - let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); - let start_port_index = tailloop.just_inputs.len(); - { - tailloop.just_inputs.to_mut().extend(types.iter().cloned()); - hugr.replace_op(node, tailloop).unwrap(); + fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option { + self.0.get(&parent).and_then(|m| m.get(&source).cloned()) } - let tailloop_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) - .collect_vec(); - - let input_wires = - thread_dataflow_parent(hugr, node, start_port_index, types.clone()).collect_vec(); - parent_source_map.insert( - node, - iter::zip(sources.iter().copied(), input_wires.iter().copied()), - ); - - let [_, o] = hugr.get_io(node).unwrap(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = - hugr.get_optype(o).port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum_type() else { + + fn thread_dataflow_parent( + &mut self, + hugr: &mut impl HugrMut, + parent: Node, + start_port_index: usize, + sources: impl IntoIterator, + ) -> impl Iterator { + let [input_n, _] = hugr.get_io(parent).unwrap(); + let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { panic!("impossible") }; + let mut input_wires = vec![]; + self.0 + .entry(parent) + .or_default() + .extend(sources.into_iter().enumerate().map(|(i, (source, ty))| { + input.types.to_mut().insert(start_port_index + i, ty); + let input_wire = Wire::new( + input_n, + hugr.insert_outgoing_port(input_n, start_port_index + i), + ); + input_wires.push(input_wire); + (source, input_wire) + })); + hugr.replace_op(input_n, input).unwrap(); + input_wires.into_iter() + } +} - let old_sum_rows: Vec = sum_type - .iter_variants() - .map(|x| x.clone().try_into().unwrap()) - .collect_vec(); - let new_sum_rows = { - let mut v = old_sum_rows.clone(); - v[TailLoop::CONTINUE_TAG] - .to_mut() - .extend(types.iter().cloned()); - v - }; +#[derive(Clone, Debug)] +struct ThreadState<'a> { + parent_source_map: ParentSourceMap, + needs: &'a BBNeedsSourcesMap, + worklist: Vec, +} - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = - ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()).unwrap(); - for i in 0..2 { - let mut case = cond.case_builder(i).unwrap(); - let inputs = { - let all_inputs = case.input_wires(); - if i == TailLoop::CONTINUE_TAG { - Either::Left(all_inputs) +impl<'a> ThreadState<'a> { + delegate! { + to self.parent_source_map { + // fn contains_parent(&self, parent: Node) -> bool; + fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; + fn insert_sources_in_parent(&mut self, parent: Node, sources: impl IntoIterator); + fn thread_dataflow_parent( + &mut self, + hugr: &mut impl HugrMut, + parent: Node, + start_port_index: usize, + sources: impl IntoIterator, + ) -> impl Iterator; + } + } + + fn new(bbnsm: &'a BBNeedsSourcesMap) -> Self { + Self { + parent_source_map: Default::default(), + needs: bbnsm, + worklist: vec![], + } + } + + fn do_dataflow_block( + &mut self, + hugr: &mut impl HugrMut, + node: Node, + sources: Vec<(Wire, Type)>, + ) { + let types = sources.iter().map(|x| x.1.clone()).collect_vec(); + let new_sum_row_prefixes = { + let mut dfb = hugr.get_optype(node).as_dataflow_block().unwrap().clone(); + let mut nsrp = vec![vec![]; dfb.sum_rows.len()]; + dfb.inputs.to_mut().extend(types.clone()); + for (this_p, succ_n) in hugr.node_outputs(node).filter_map(|out_p| { + let (succ_n, _) = hugr.single_linked_input(node, out_p).unwrap(); + if hugr.get_optype(succ_n).is_exit_block() { + None } else { - Either::Right(all_inputs.into_iter().dropping_back(types.len())) + Some((out_p.index(), succ_n)) } + }) { + let succ_needs = &self.needs[&succ_n]; + let new_tys = succ_needs + .iter() + .map(|(&w, ty)| { + ( + sources.iter().find_position(|(x, _)| x == &w).unwrap().0, + ty.clone(), + ) + }) + .collect_vec(); + nsrp[this_p] = new_tys.clone(); + let tys = dfb.sum_rows[this_p].to_mut(); + let old_tys = mem::replace(tys, new_tys.into_iter().map(|x| x.1).collect_vec()); + tys.extend(old_tys); + } + hugr.replace_op(node, dfb).unwrap(); + nsrp + }; + + let input_wires = self + .thread_dataflow_parent(hugr, node, 0, sources.clone()) + .collect_vec(); + + let [_, o] = hugr.get_io(node).unwrap(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = + hugr.get_optype(o).port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum_type() else { + panic!("impossible") }; - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); + let old_sum_rows: Vec = sum_type + .iter_variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows: Vec = + itertools::zip_eq(new_sum_row_prefixes.clone(), old_sum_rows.iter()) + .map(|(new, old)| { + new.into_iter() + .map(|x| x.1) + .chain(old.iter().cloned()) + .collect_vec() + .into() + }) + .collect_vec(); + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = ConditionalBuilder::new( + old_sum_rows.clone(), + types.clone(), + new_control_type.clone(), + ) + .unwrap(); + for (i, row) in new_sum_row_prefixes.iter().enumerate() { + let mut case = cond.case_builder(i).unwrap(); + let case_inputs = case.input_wires().collect_vec(); + let mut args = vec![]; + for (source_i, _) in row { + args.push(case_inputs[old_sum_rows[i].len() + source_i]); + } + + args.extend(&case_inputs[..old_sum_rows[i].len()]); + + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(node, cond).new_root; + let (n, p) = hugr.single_linked_output(o, 0).unwrap(); + hugr.connect(n, p, cond_node, 0); + for (i, w) in input_wires.into_iter().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); } - (cond.finish_hugr().unwrap(), new_control_type) - }; - let cond_node = hugr.insert_hugr(node, cond).new_root; - let (n, p) = hugr.single_linked_output(o, 0).unwrap(); - hugr.connect(n, p, cond_node, 0); - for (i, w) in input_wires.into_iter().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); + hugr.disconnect(o, IncomingPort::from(0)); + hugr.connect(cond_node, 0, o, 0); + let mut output = hugr.get_optype(o).as_output().unwrap().clone(); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(o, output).unwrap(); } - hugr.disconnect(o, IncomingPort::from(0)); - hugr.connect(cond_node, 0, o, 0); - let mut output = hugr.get_optype(o).as_output().unwrap().clone(); - output.types.to_mut()[0] = new_control_type; - hugr.replace_op(o, output).unwrap(); - mk_workitems(node, sources, tailloop_ports, types) -} -#[derive(Clone, Default, Debug)] -struct ParentSourceMap(HashMap>); + fn do_cfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let types = sources.iter().map(|x| x.1.clone()).collect_vec(); + { + let mut cfg = hugr.get_optype(node).as_cfg().unwrap().clone(); + let inputs = cfg.signature.input.to_mut(); + let old_inputs = mem::replace(inputs, types); + inputs.extend(old_inputs); + hugr.replace_op(node, cfg).unwrap(); + } + let new_cond_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, i)) + .collect_vec(); + self.insert_sources_in_parent(node, iter::empty()); + self.worklist + .extend(mk_workitems(node, sources, new_cond_ports)) + } -impl ParentSourceMap { - fn contains_parent(&self, parent: Node) -> bool { - self.0.contains_key(&parent) + fn do_dfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut dfg = hugr.get_optype(node).as_dfg().unwrap().clone(); + let start_new_port_index = dfg.signature.input().len(); + let new_dfg_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) + .collect_vec(); + dfg.signature + .input + .to_mut() + .extend(sources.iter().map(|x| x.1.clone())); + hugr.replace_op(node, dfg).unwrap(); + let _ = + self.thread_dataflow_parent(hugr, node, start_new_port_index, sources.iter().cloned()); + self.worklist + .extend(mk_workitems(node, sources, new_dfg_ports)); } - fn insert(&mut self, parent: Node, sources: impl IntoIterator) { - debug_assert!(!self.0.contains_key(&parent)); - self.0.entry(parent).or_default().extend(sources); + fn do_conditional(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut cond = hugr.get_optype(node).as_conditional().unwrap().clone(); + let start_new_port_index = cond.signature().input().len(); + cond.other_inputs + .to_mut() + .extend(sources.iter().map(|x| x.1.clone())); + hugr.replace_op(node, cond).unwrap(); + let new_cond_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) + .collect_vec(); + self.insert_sources_in_parent(node, iter::empty()); + self.worklist + .extend(mk_workitems(node, sources, new_cond_ports)) } - fn get(&self, parent: Node, source: Wire) -> Option { - self.0.get(&parent).and_then(|m| m.get(&source).cloned()) + fn do_case(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut case = hugr.get_optype(node).as_case().unwrap().clone(); + let start_case_port_index = case.signature.input().len(); + case.signature + .input + .to_mut() + .extend(sources.iter().map(|x| x.1.clone())); + hugr.replace_op(node, case).unwrap(); + let _ = self.thread_dataflow_parent(hugr, node, start_case_port_index, sources); + } + + fn do_tailloop(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { + let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); + let types = sources.iter().map(|x| x.1.clone()).collect_vec(); + let start_port_index = tailloop.just_inputs.len(); + { + tailloop.just_inputs.to_mut().extend(types.clone()); + hugr.replace_op(node, tailloop).unwrap(); + } + let tailloop_ports = (0..sources.len()) + .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) + .collect_vec(); + + let input_wires = self + .thread_dataflow_parent(hugr, node, start_port_index, sources.clone()) + .collect_vec(); + + let [_, o] = hugr.get_io(node).unwrap(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = + hugr.get_optype(o).port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum_type() else { + panic!("impossible") + }; + + let old_sum_rows: Vec = sum_type + .iter_variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows = { + let mut v = old_sum_rows.clone(); + v[TailLoop::CONTINUE_TAG] + .to_mut() + .extend(types.iter().cloned()); + v + }; + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = + ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()) + .unwrap(); + for i in 0..2 { + let mut case = cond.case_builder(i).unwrap(); + let inputs = { + let all_inputs = case.input_wires(); + if i == TailLoop::CONTINUE_TAG { + Either::Left(all_inputs) + } else { + Either::Right(all_inputs.into_iter().dropping_back(types.len())) + } + }; + + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(node, cond).new_root; + let (n, p) = hugr.single_linked_output(o, 0).unwrap(); + hugr.connect(n, p, cond_node, 0); + for (i, w) in input_wires.into_iter().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } + hugr.disconnect(o, IncomingPort::from(0)); + hugr.connect(cond_node, 0, o, 0); + let mut output = hugr.get_optype(o).as_output().unwrap().clone(); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(o, output).unwrap(); + self.worklist + .extend(mk_workitems(node, sources, tailloop_ports)) + } + + fn finish(self, _hugr: &mut impl HugrMut) -> (Vec, ParentSourceMap) { + (self.worklist, self.parent_source_map) } } +fn thread_sources( + hugr: &mut impl HugrMut, + bb_needs_sources_map: &BBNeedsSourcesMap, +) -> (Vec, ParentSourceMap) { + let mut state = ThreadState::new(bb_needs_sources_map); + for (&bb, sources) in bb_needs_sources_map { + let sources = sources + .iter() + .map(|(&w, ty)| (w, ty.clone())) + .collect_vec(); + match hugr.get_optype(bb).clone() { + OpType::DFG(_) => state.do_dfg(hugr, bb, sources), + OpType::Conditional(_) => state.do_conditional(hugr, bb, sources), + OpType::Case(_) => state.do_case(hugr, bb, sources), + OpType::TailLoop(_) => state.do_tailloop(hugr, bb, sources), + OpType::DataflowBlock(_) => state.do_dataflow_block(hugr, bb, sources), + OpType::CFG(_) => state.do_cfg(hugr, bb, sources), + _ => panic!("impossible"), + } + } + + state.finish(hugr) +} + fn mk_workitems( node: Node, - sources: impl IntoIterator, + sources: impl IntoIterator, ports: impl IntoIterator, - types: impl IntoIterator, ) -> impl Iterator { - itertools::izip!(sources, ports, types).map(move |(source, p, ty)| WorkItem { + itertools::izip!(sources, ports).map(move |((source, ty), p)| WorkItem { source, target: (node, p), ty, }) } -fn thread_sources( - parent_source_map: &mut ParentSourceMap, - hugr: &mut impl HugrMut, - bb: Node, - sources: impl IntoIterator, -) -> Vec { - let (source_wires, types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); - match hugr.get_optype(bb).clone() { - OpType::DFG(mut dfg) => { - debug_assert!(!parent_source_map.contains_parent(bb)); - let start_new_port_index = dfg.signature.input().len(); - let new_dfg_ports = (0..source_wires.len()) - .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) - .collect_vec(); - dfg.signature.input.to_mut().extend(types.clone()); - hugr.replace_op(bb, dfg).unwrap(); - for (source, &target) in iter::zip(source_wires.iter(), new_dfg_ports.iter()) { - hugr.connect(source.node(), source.source(), bb, target); - } - parent_source_map.insert( - bb, - iter::zip( - source_wires.iter().copied(), - thread_dataflow_parent(hugr, bb, start_new_port_index, types.clone()), - ), - ); - mk_workitems(bb, source_wires, new_dfg_ports, types).collect_vec() - } - OpType::Conditional(mut cond) => { - debug_assert!(!parent_source_map.contains_parent(bb)); - let start_new_port_index = cond.signature().input().len(); - cond.other_inputs.to_mut().extend(types.clone()); - hugr.replace_op(bb, cond).unwrap(); - let new_cond_ports = (0..source_wires.len()) - .map(|i| hugr.insert_incoming_port(bb, start_new_port_index + i)) - .collect_vec(); - parent_source_map.insert(bb, iter::empty()); - mk_workitems(bb, source_wires, new_cond_ports, types).collect_vec() - } - OpType::Case(mut case) => { - debug_assert!(!parent_source_map.contains_parent(bb)); - let start_case_port_index = case.signature.input().len(); - case.signature.input.to_mut().extend(types.clone()); - hugr.replace_op(bb, case).unwrap(); - parent_source_map.insert( - bb, - iter::zip( - source_wires.iter().copied(), - thread_dataflow_parent(hugr, bb, start_case_port_index, types), - ), - ); - vec![] - } - OpType::TailLoop(_) => { - do_tailloop(parent_source_map, hugr, bb, iter::zip(source_wires, types)).collect_vec() - } - _ => panic!("impossible"), - } - // _ => panic!("impossible"), - // }; - // non_local_edges.push(workitem); - // wire - // }; - // hugr.disconnect(target.0, target.1); - // hugr.connect( - // local_source.node(), - // local_source.source(), - // target.0, - // target.1, - // ); - // } -} +type BBNeedsSourcesMap = HashMap>; #[derive(Debug, Default, Clone)] -struct BBNeedsSourcesMapBuilder(HashMap>); +struct BBNeedsSourcesMapBuilder(BBNeedsSourcesMap); impl BBNeedsSourcesMapBuilder { fn insert(&mut self, bb: Node, source: Wire, ty: Type) { @@ -320,7 +476,8 @@ impl BBNeedsSourcesMapBuilder { .get(&child) .into_iter() .flat_map(move |m| { - m.iter().filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) + m.iter() + .filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) .map(|(&w, ty)| (w, ty.clone())) }) .collect_vec(); @@ -331,15 +488,15 @@ impl BBNeedsSourcesMapBuilder { any } - fn finish(mut self, hugr: impl HugrView) -> HashMap> { - let conds = self - .0 - .keys() - .copied() - .filter(|&n| hugr.get_optype(n).is_conditional()) - .collect_vec(); - for cond in conds { - if hugr.get_optype(cond).is_conditional() { + fn finish(mut self, hugr: impl HugrView) -> BBNeedsSourcesMap { + { + let conds = self + .0 + .keys() + .copied() + .filter(|&n| hugr.get_optype(n).is_conditional()) + .collect_vec(); + for cond in conds { let cases = hugr .children(cond) .filter(|&child| hugr.get_optype(child).is_case()) @@ -358,6 +515,40 @@ impl BBNeedsSourcesMapBuilder { } } } + { + let cfgs = self + .0 + .keys() + .copied() + .filter(|&n| hugr.get_optype(n).is_cfg() && self.0.contains_key(&n)) + .collect_vec(); + for cfg in cfgs { + let dfbs = hugr + .children(cfg) + .filter(|&child| hugr.get_optype(child).is_dataflow_block()) + .collect_vec(); + + // let mut dfb_needs_map: HashMap<_, _> = dfbs + // .iter() + // .map(|&n| (n, self.0.get(&n).cloned().unwrap_or_default())) + // .collect(); + loop { + let mut any_change = false; + for &dfb in dfbs.iter() { + for succ_n in hugr.output_neighbours(dfb) { + for (w, ty) in self.0.get(&succ_n).cloned().unwrap_or_default() { + any_change |= + self.0.entry(dfb).or_default().insert(w, ty).is_none(); + } + } + } + if !any_change { + break; + } + } + } + } + self.0 } } @@ -392,10 +583,7 @@ pub fn remove_nonlocal_edges( let bb_needs_sources_map = { let nonlocal_sorted = { let mut v = iter::successors(Some(vec![root]), |nodes| { - let children = nodes - .iter() - .flat_map(|&n| hugr.children(n)) - .collect_vec(); + let children = nodes.iter().flat_map(|&n| hugr.children(n)).collect_vec(); (!children.is_empty()).then_some(children) }) .flatten() @@ -450,26 +638,21 @@ pub fn remove_nonlocal_edges( } } - let mut worklist = nonlocal_edges_map.into_values().collect_vec(); - let mut parent_source_map = ParentSourceMap::default(); - - for (bb, needs_sources) in bb_needs_sources_map { - worklist.extend(thread_sources( - &mut parent_source_map, - hugr, - bb, - needs_sources, - )); - } - - let parent_source_map = parent_source_map; + let (parent_source_map, worklist) = { + let mut worklist = nonlocal_edges_map.into_values().collect_vec(); + let (wl, psm) = thread_sources(hugr, &bb_needs_sources_map); + worklist.extend(wl); + (psm, worklist) + }; - while let Some(wi) = worklist.pop() { + for wi in worklist { let parent = hugr.get_parent(wi.target.0).unwrap(); let source = if hugr.get_parent(wi.source.node()).unwrap() == parent { wi.source } else { - parent_source_map.get(parent, wi.source).unwrap() + parent_source_map + .get_source_in_parent(parent, wi.source) + .unwrap() }; debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); hugr.disconnect(wi.target.0, wi.target.1); @@ -537,7 +720,7 @@ mod test { } #[test] - fn remove_nonlocal_edges_dfg() { + fn dfg() { let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [w0] = outer.input_wires_arr(); @@ -557,7 +740,7 @@ mod test { } #[test] - fn remove_nonlocal_edges_tailloop() { + fn tailloop() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(vec![ @@ -601,7 +784,7 @@ mod test { } #[test] - fn remove_nonlocal_edges_cond() { + fn conditional() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = Type::new_sum(out_variants.clone()); @@ -652,4 +835,47 @@ mod test { hugr.validate().unwrap_or_else(|e| panic!("{e}")); assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); } + + #[test] + fn cfg() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + // let out_variants = vec![t1.clone().into(), t2.clone().into()]; + let out_type = t1.clone(); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new( + vec![t1.clone(), t2.clone(), t3.clone()], + out_type.clone(), + )) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [out] = { + let mut cfg = outer.cfg_builder([], out_type.into()).unwrap(); + + let entry = { + let mut entry = cfg.entry_builder([type_row![]], type_row![]).unwrap(); + let w = entry.add_load_value(Value::unit()); + entry.finish_with_outputs(w, []).unwrap() + }; + let exit = cfg.exit_block(); + + let bb1 = { + let mut entry = cfg + .block_builder(type_row![], [type_row![]], t1.clone().into()) + .unwrap(); + let w = entry.add_load_value(Value::unit()); + entry.finish_with_outputs(w, [s1]).unwrap() + }; + cfg.branch(&entry, 0, &bb1).unwrap(); + cfg.branch(&bb1, 0, &exit).unwrap(); + cfg.finish_sub_container().unwrap().outputs_arr() + }; + outer.finish_hugr_with_outputs([out]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + let root = hugr.root(); + remove_nonlocal_edges(&mut hugr, root).unwrap(); + println!("{}", hugr.mermaid_string()); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } } From 8f0b164b693e62448573e8467c100f15f6de1fae Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sat, 8 Feb 2025 14:19:48 +0000 Subject: [PATCH 08/23] wip --- hugr-passes/src/non_local.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 22bef220e8..09feb42c14 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -720,7 +720,7 @@ mod test { } #[test] - fn dfg() { + fn unnonlocal_dfg() { let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [w0] = outer.input_wires_arr(); @@ -740,7 +740,7 @@ mod test { } #[test] - fn tailloop() { + fn unnonlocal_tailloop() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let mut hugr = { let mut outer = DFGBuilder::new(Signature::new_endo(vec![ @@ -784,7 +784,7 @@ mod test { } #[test] - fn conditional() { + fn unnonlocal_conditional() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = Type::new_sum(out_variants.clone()); @@ -837,7 +837,7 @@ mod test { } #[test] - fn cfg() { + fn unnonlocal_cfg() { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); // let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = t1.clone(); From af2959a77b4cbeb4d0042b94dbab56640b6a3edb Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sat, 8 Feb 2025 14:26:39 +0000 Subject: [PATCH 09/23] feat: Add `Type::as_sum` and `SumType::variants`. --- .../std_extensions/arithmetic/float_types.rs | 1 + .../std_extensions/arithmetic/int_types.rs | 1 + .../src/std_extensions/collections/array.rs | 1 + hugr-core/src/types.rs | 31 +++++++++++++++---- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 304b940453..579d89e6bb 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -65,6 +65,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Name of the constructor for creating constant 64bit floats. + #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const-f64"; /// Create a new [`ConstF64`] diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 1342dd9320..e5d625695e 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -105,6 +105,7 @@ pub struct ConstInt { impl ConstInt { /// Name of the constructor for creating constant integers. + #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.int.const"; /// Create a new [`ConstInt`] with a given width and unsigned value diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 618bd61826..93f58727cb 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -43,6 +43,7 @@ pub struct ArrayValue { impl ArrayValue { /// Name of the constructor for creating constant arrays. + #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "collections.array.const"; /// Create a new [CustomConst] for an array of values of type `typ`. diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 962e00876c..c22c1fff8a 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -27,7 +27,7 @@ pub use type_row::{TypeRow, TypeRowRV}; pub(crate) use poly_func::PolyFuncTypeBase; use itertools::FoldWhile::{Continue, Done}; -use itertools::{repeat_n, Itertools}; +use itertools::{Either, Itertools as _}; #[cfg(test)] use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; @@ -189,7 +189,7 @@ impl std::fmt::Display for SumType { SumType::Unit { size: 1 } => write!(f, "Unit"), SumType::Unit { size: 2 } => write!(f, "Bool"), SumType::Unit { size } => { - display_list_with_separator(repeat_n("[]", *size as usize), f, "+") + display_list_with_separator(itertools::repeat_n("[]", *size as usize), f, "+") } SumType::General { rows } => match rows.len() { 1 if rows[0].is_empty() => write!(f, "Unit"), @@ -216,17 +216,17 @@ impl SumType { } } - /// New UnitSum with empty Tuple variants + /// New UnitSum with empty Tuple variants. pub const fn new_unary(size: u8) -> Self { Self::Unit { size } } - /// New tuple (single row of variants) + /// New tuple (single row of variants). pub fn new_tuple(types: impl Into) -> Self { Self::new([types.into()]) } - /// New option type (either an empty option, or a row of types) + /// New option type (either an empty option, or a row of types). pub fn new_option(types: impl Into) -> Self { Self::new([vec![].into(), types.into()]) } @@ -248,7 +248,7 @@ impl SumType { } } - /// Returns variant row if there is only one variant + /// Returns variant row if there is only one variant. pub fn as_tuple(&self) -> Option<&TypeRowRV> { match self { SumType::Unit { size } if *size == 1 => Some(TypeRV::EMPTY_TYPEROW_REF), @@ -256,6 +256,17 @@ impl SumType { _ => None, } } + + /// Returns an iterator over the variants. + pub fn variants(&self) -> impl Iterator { + match self { + SumType::Unit { size } => Either::Left(itertools::repeat_n( + TypeRV::EMPTY_TYPEROW_REF, + *size as usize, + )), + SumType::General { rows } => Either::Right(rows.iter()), + } + } } impl From for TypeBase { @@ -453,6 +464,14 @@ impl TypeBase { &mut self.0 } + /// Returns the inner [SumType] if the type is a sum. + pub fn as_sum(&self) -> Option<&SumType> { + match &self.0 { + TypeEnum::Sum(s) => Some(s), + _ => None, + } + } + /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { From 2e3e282ef8aff7c5c0cd76c34b2b187839622a2a Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Feb 2025 09:04:42 +0000 Subject: [PATCH 10/23] coverage --- hugr-core/src/types.rs | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index c22c1fff8a..fa10e2a609 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -732,13 +732,37 @@ pub(crate) mod test { assert_eq!(pred1, Type::from(pred_direct)); } + #[test] + fn as_sum() { + let t = Type::new_unit_sum(0); + assert!(t.as_sum().is_some()); + } + + #[test] + fn sum_variants() { + { + let variants: Vec = vec![ + TypeRV::UNIT.into(), + vec![TypeRV::new_row_var_use(0, TypeBound::Any)].into(), + ]; + let t = SumType::new(variants.clone()); + assert_eq!(variants, t.variants().cloned().collect_vec()); + } + { + assert_eq!( + vec![&TypeRV::EMPTY_TYPEROW;3], + SumType::new_unary(3).variants().collect_vec() + ); + } + } + mod proptest { use crate::proptest::RecursionDepth; use super::{AliasDecl, MaybeRV, TypeBase, TypeBound, TypeEnum}; use crate::types::{CustomType, FuncValueType, SumType, TypeRowRV}; - use ::proptest::prelude::*; + use proptest::prelude::*; impl Arbitrary for super::SumType { type Parameters = RecursionDepth; From dc95bc972666b3823ad3d134ac26891e32fd3ff8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Feb 2025 09:48:27 +0000 Subject: [PATCH 11/23] wip --- hugr-passes/src/non_local.rs | 84 +++++++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 09feb42c14..71b127e89e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -224,12 +224,12 @@ impl<'a> ThreadState<'a> { else { panic!("impossible") }; - let Some(sum_type) = control_type.as_sum_type() else { + let Some(sum_type) = control_type.as_sum() else { panic!("impossible") }; let old_sum_rows: Vec = sum_type - .iter_variants() + .variants() .map(|x| x.clone().try_into().unwrap()) .collect_vec(); let new_sum_rows: Vec = @@ -364,12 +364,12 @@ impl<'a> ThreadState<'a> { else { panic!("impossible") }; - let Some(sum_type) = control_type.as_sum_type() else { + let Some(sum_type) = control_type.as_sum() else { panic!("impossible") }; let old_sum_rows: Vec = sum_type - .iter_variants() + .variants() .map(|x| x.clone().try_into().unwrap()) .collect_vec(); let new_sum_rows = { @@ -665,8 +665,8 @@ pub fn remove_nonlocal_edges( #[cfg(test)] mod test { use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, - extension::prelude::{bool_t, Noop}, + builder::{Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, + extension::prelude::{bool_t, either_type, option_type, Noop}, ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, types::Signature, @@ -841,35 +841,77 @@ mod test { let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); // let out_variants = vec![t1.clone().into(), t2.clone().into()]; let out_type = t1.clone(); + // Cfg consists of 4 dataflow blocks and an exit block + // + // The 4 dataflow blocks form a diamond, and the bottom block branches + // either to the entry block or the exit block. + // + // Two non-local uses in the left block means that these values must + // be threaded through all blocks, because of the loop. + // + // All non-trivial(i.e. more than one choice of successor) branching is + // done on an option type to exercise both empty and occupied control + // sums. + // + // All branches have an other-output. let mut hugr = { + let branch_sum_type = either_type(Type::UNIT, Type::UNIT); + let branch_type = Type::from(branch_sum_type.clone()); + let branch_variants = branch_sum_type.variants().cloned().map(|x| x.try_into().unwrap()).collect_vec(); + let nonlocal1_type = bool_t(); + let nonlocal2_type = Type::new_unit_sum(3); + let other_output_type = branch_type.clone(); let mut outer = DFGBuilder::new(Signature::new( - vec![t1.clone(), t2.clone(), t3.clone()], - out_type.clone(), + vec![branch_type.clone(), nonlocal1_type.clone(), nonlocal2_type.clone(), Type::UNIT], + vec![Type::UNIT, other_output_type.clone()] )) .unwrap(); - let [s1, s2, s3] = outer.input_wires_arr(); - let [out] = { - let mut cfg = outer.cfg_builder([], out_type.into()).unwrap(); + let [b, nl1, nl2, unit] = outer.input_wires_arr(); + let [unit, out] = { + let mut cfg = outer.cfg_builder([(Type::UNIT, unit), (branch_type.clone(), b)], vec![Type::UNIT, other_output_type.clone()].into()).unwrap(); let entry = { - let mut entry = cfg.entry_builder([type_row![]], type_row![]).unwrap(); - let w = entry.add_load_value(Value::unit()); - entry.finish_with_outputs(w, []).unwrap() + let entry = cfg.entry_builder(branch_variants.clone(), other_output_type.clone().into()).unwrap(); + let [_, b] = entry.input_wires_arr(); + + entry.finish_with_outputs(b, [b]).unwrap() }; let exit = cfg.exit_block(); - let bb1 = { + let bb_left = { let mut entry = cfg - .block_builder(type_row![], [type_row![]], t1.clone().into()) + .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .unwrap(); + let [unit, oo] = entry.input_wires_arr(); + let [_] = entry.add_dataflow_op(Noop::new(nonlocal1_type), [nl1]).unwrap().outputs_arr(); + let [_] = entry.add_dataflow_op(Noop::new(nonlocal2_type), [nl2]).unwrap().outputs_arr(); + entry.finish_with_outputs(unit, [oo]).unwrap() + }; + + let bb_right = { + let entry = cfg + .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .unwrap(); + let [b, oo] = entry.input_wires_arr(); + entry.finish_with_outputs(unit, [oo]).unwrap() + }; + + let bb_bottom = { + let entry = cfg + .block_builder(branch_type.clone().into(), branch_variants, other_output_type.clone().into()) .unwrap(); - let w = entry.add_load_value(Value::unit()); - entry.finish_with_outputs(w, [s1]).unwrap() + let [oo] = entry.input_wires_arr(); + entry.finish_with_outputs(oo, [oo]).unwrap() }; - cfg.branch(&entry, 0, &bb1).unwrap(); - cfg.branch(&bb1, 0, &exit).unwrap(); + cfg.branch(&entry, 0, &bb_left).unwrap(); + cfg.branch(&entry, 1, &bb_right).unwrap(); + cfg.branch(&bb_left, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_right, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_bottom, 0, &entry).unwrap(); + cfg.branch(&bb_bottom, 1, &exit).unwrap(); cfg.finish_sub_container().unwrap().outputs_arr() }; - outer.finish_hugr_with_outputs([out]).unwrap() + outer.finish_hugr_with_outputs([unit, out]).unwrap() }; assert!(ensure_no_nonlocal_edges(&hugr).is_err()); let root = hugr.root(); From 9f619b47f5a145e0533439cc9dfbdae9b56314f5 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 09:37:33 +0000 Subject: [PATCH 12/23] wip --- hugr-core/src/types.rs | 2 +- hugr-passes/src/non_local.rs | 123 +++++++++++++++++++++-------------- 2 files changed, 76 insertions(+), 49 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index fa10e2a609..d69c7ef7d6 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -750,7 +750,7 @@ pub(crate) mod test { } { assert_eq!( - vec![&TypeRV::EMPTY_TYPEROW;3], + vec![&TypeRV::EMPTY_TYPEROW; 3], SumType::new_unary(3).variants().collect_vec() ); } diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 71b127e89e..8600f19c50 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -155,7 +155,7 @@ impl<'a> ThreadState<'a> { delegate! { to self.parent_source_map { // fn contains_parent(&self, parent: Node) -> bool; - fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; + // fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; fn insert_sources_in_parent(&mut self, parent: Node, sources: impl IntoIterator); fn thread_dataflow_parent( &mut self, @@ -183,33 +183,24 @@ impl<'a> ThreadState<'a> { ) { let types = sources.iter().map(|x| x.1.clone()).collect_vec(); let new_sum_row_prefixes = { - let mut dfb = hugr.get_optype(node).as_dataflow_block().unwrap().clone(); - let mut nsrp = vec![vec![]; dfb.sum_rows.len()]; - dfb.inputs.to_mut().extend(types.clone()); + let mut this_dfb = hugr.get_optype(node).as_dataflow_block().unwrap().clone(); + let mut nsrp = vec![vec![]; this_dfb.sum_rows.len()]; + vec_prepend(this_dfb.inputs.to_mut(), types.clone()); + for (this_p, succ_n) in hugr.node_outputs(node).filter_map(|out_p| { let (succ_n, _) = hugr.single_linked_input(node, out_p).unwrap(); - if hugr.get_optype(succ_n).is_exit_block() { - None - } else { - Some((out_p.index(), succ_n)) - } + hugr.get_optype(succ_n).is_dataflow_block().then_some((out_p.index(), succ_n)) }) { let succ_needs = &self.needs[&succ_n]; - let new_tys = succ_needs + let succ_needs_source_indices = succ_needs .iter() - .map(|(&w, ty)| { - ( - sources.iter().find_position(|(x, _)| x == &w).unwrap().0, - ty.clone(), - ) - }) + .map(|(&w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) .collect_vec(); - nsrp[this_p] = new_tys.clone(); - let tys = dfb.sum_rows[this_p].to_mut(); - let old_tys = mem::replace(tys, new_tys.into_iter().map(|x| x.1).collect_vec()); - tys.extend(old_tys); + let succ_needs_tys = succ_needs_source_indices.iter().copied().map(|x| sources[x].1.clone()).collect_vec(); + vec_prepend(this_dfb.sum_rows[this_p].to_mut(), succ_needs_tys); + nsrp[this_p] = succ_needs_source_indices; } - hugr.replace_op(node, dfb).unwrap(); + hugr.replace_op(node, this_dfb).unwrap(); nsrp }; @@ -233,11 +224,11 @@ impl<'a> ThreadState<'a> { .map(|x| x.clone().try_into().unwrap()) .collect_vec(); let new_sum_rows: Vec = - itertools::zip_eq(new_sum_row_prefixes.clone(), old_sum_rows.iter()) - .map(|(new, old)| { - new.into_iter() - .map(|x| x.1) - .chain(old.iter().cloned()) + itertools::zip_eq(new_sum_row_prefixes.iter(), old_sum_rows.iter()) + .map(|(new_source_indices, old_tys)| { + new_source_indices.into_iter() + .map(|&x| sources[x].1.clone()) + .chain(old_tys.iter().cloned()) .collect_vec() .into() }) @@ -250,11 +241,11 @@ impl<'a> ThreadState<'a> { new_control_type.clone(), ) .unwrap(); - for (i, row) in new_sum_row_prefixes.iter().enumerate() { + for (i, new_source_indices) in new_sum_row_prefixes.into_iter().enumerate() { let mut case = cond.case_builder(i).unwrap(); let case_inputs = case.input_wires().collect_vec(); let mut args = vec![]; - for (source_i, _) in row { + for source_i in new_source_indices { args.push(case_inputs[old_sum_rows[i].len() + source_i]); } @@ -279,6 +270,7 @@ impl<'a> ThreadState<'a> { let mut output = hugr.get_optype(o).as_output().unwrap().clone(); output.types.to_mut()[0] = new_control_type; hugr.replace_op(o, output).unwrap(); + dbg!(hugr.single_linked_output(o, 0)); } fn do_cfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { @@ -429,10 +421,7 @@ fn thread_sources( ) -> (Vec, ParentSourceMap) { let mut state = ThreadState::new(bb_needs_sources_map); for (&bb, sources) in bb_needs_sources_map { - let sources = sources - .iter() - .map(|(&w, ty)| (w, ty.clone())) - .collect_vec(); + let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); match hugr.get_optype(bb).clone() { OpType::DFG(_) => state.do_dfg(hugr, bb, sources), OpType::Conditional(_) => state.do_conditional(hugr, bb, sources), @@ -662,11 +651,18 @@ pub fn remove_nonlocal_edges( Ok(()) } +fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { + let mut old_v = mem::replace(v, ts.into_iter().collect()); + v.extend(old_v.drain(..)); +} + #[cfg(test)] mod test { use hugr_core::{ - builder::{Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, - extension::prelude::{bool_t, either_type, option_type, Noop}, + builder::{ + DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, + }, + extension::prelude::{bool_t, either_type, Noop}, ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, types::Signature, @@ -838,9 +834,6 @@ mod test { #[test] fn unnonlocal_cfg() { - let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); - // let out_variants = vec![t1.clone().into(), t2.clone().into()]; - let out_type = t1.clone(); // Cfg consists of 4 dataflow blocks and an exit block // // The 4 dataflow blocks form a diamond, and the bottom block branches @@ -857,21 +850,37 @@ mod test { let mut hugr = { let branch_sum_type = either_type(Type::UNIT, Type::UNIT); let branch_type = Type::from(branch_sum_type.clone()); - let branch_variants = branch_sum_type.variants().cloned().map(|x| x.try_into().unwrap()).collect_vec(); + let branch_variants = branch_sum_type + .variants() + .cloned() + .map(|x| x.try_into().unwrap()) + .collect_vec(); let nonlocal1_type = bool_t(); let nonlocal2_type = Type::new_unit_sum(3); let other_output_type = branch_type.clone(); let mut outer = DFGBuilder::new(Signature::new( - vec![branch_type.clone(), nonlocal1_type.clone(), nonlocal2_type.clone(), Type::UNIT], - vec![Type::UNIT, other_output_type.clone()] + vec![ + branch_type.clone(), + nonlocal1_type.clone(), + nonlocal2_type.clone(), + Type::UNIT, + ], + vec![Type::UNIT, other_output_type.clone()], )) .unwrap(); let [b, nl1, nl2, unit] = outer.input_wires_arr(); let [unit, out] = { - let mut cfg = outer.cfg_builder([(Type::UNIT, unit), (branch_type.clone(), b)], vec![Type::UNIT, other_output_type.clone()].into()).unwrap(); + let mut cfg = outer + .cfg_builder( + [(Type::UNIT, unit), (branch_type.clone(), b)], + vec![Type::UNIT, other_output_type.clone()].into(), + ) + .unwrap(); let entry = { - let entry = cfg.entry_builder(branch_variants.clone(), other_output_type.clone().into()).unwrap(); + let entry = cfg + .entry_builder(branch_variants.clone(), other_output_type.clone().into()) + .unwrap(); let [_, b] = entry.input_wires_arr(); entry.finish_with_outputs(b, [b]).unwrap() @@ -880,25 +889,43 @@ mod test { let bb_left = { let mut entry = cfg - .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) .unwrap(); let [unit, oo] = entry.input_wires_arr(); - let [_] = entry.add_dataflow_op(Noop::new(nonlocal1_type), [nl1]).unwrap().outputs_arr(); - let [_] = entry.add_dataflow_op(Noop::new(nonlocal2_type), [nl2]).unwrap().outputs_arr(); + let [_] = entry + .add_dataflow_op(Noop::new(nonlocal1_type), [nl1]) + .unwrap() + .outputs_arr(); + let [_] = entry + .add_dataflow_op(Noop::new(nonlocal2_type), [nl2]) + .unwrap() + .outputs_arr(); entry.finish_with_outputs(unit, [oo]).unwrap() }; let bb_right = { let entry = cfg - .block_builder(vec![Type::UNIT, other_output_type.clone()].into(), [type_row![]], other_output_type.clone().into()) + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) .unwrap(); - let [b, oo] = entry.input_wires_arr(); + let [_b, oo] = entry.input_wires_arr(); entry.finish_with_outputs(unit, [oo]).unwrap() }; let bb_bottom = { let entry = cfg - .block_builder(branch_type.clone().into(), branch_variants, other_output_type.clone().into()) + .block_builder( + branch_type.clone().into(), + branch_variants, + other_output_type.clone().into(), + ) .unwrap(); let [oo] = entry.input_wires_arr(); entry.finish_with_outputs(oo, [oo]).unwrap() From de0c513b46af102f93a064ae98ad0a2ff1df752a Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 9 Feb 2025 08:54:54 +0000 Subject: [PATCH 13/23] feat: Add `HugrMutInternals::insert_ports` --- hugr-core/src/hugr/internal.rs | 97 +++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 3f1c6b6ff7..1f67ff873a 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -6,7 +6,8 @@ use std::rc::Rc; use std::sync::Arc; use delegate::delegate; -use portgraph::{LinkView, MultiPortGraph, PortMut, PortView}; +use itertools::Itertools; +use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; use crate::ops::handle::NodeHandle; use crate::ops::OpTrait; @@ -174,6 +175,26 @@ pub trait HugrMutInternals: RootTagged { self.hugr_mut().add_ports(node, direction, amount) } + /// Insert `amount` new ports for a node, starting at `index`. The + /// `direction` parameter specifies whether to add ports to the incoming or + /// outgoing list. Links from this node are preserved, even when ports are + /// renumbered by the insertion. + /// + /// Returns the range of newly created ports. + /// # Panics + /// + /// If the node is not in the graph. + fn insert_ports( + &mut self, + node: Node, + direction: Direction, + index: usize, + amount: usize, + ) -> Range { + panic_invalid_node(self, node); + self.hugr_mut().insert_ports(node, direction, index, amount) + } + /// Sets the parent of a node. /// /// The node becomes the parent's last child. @@ -260,6 +281,46 @@ impl + AsMut> HugrMutInternals for T { .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}) } + fn insert_ports( + &mut self, + node: Node, + direction: Direction, + index: usize, + amount: usize, + ) -> Range { + let old_num_ports = match direction { + Direction::Incoming => self.base_hugr().graph.num_inputs(node.pg_index()), + Direction::Outgoing => self.base_hugr().graph.num_outputs(node.pg_index()), + }; + + let new_ports = self.add_ports(node, direction, amount as isize); + + for swap_from_port in (index..old_num_ports).rev() { + let swap_to_port = swap_from_port + amount; + let [from_port_index, to_port_index] = [swap_from_port, swap_to_port].map(|p| { + self.base_hugr() + .graph + .port_index(node.pg_index(), PortOffset::new(direction, p)) + .unwrap() + }); + let linked_ports = self + .base_hugr() + .graph + .port_links(from_port_index) + .map(|(_, to_subport)| to_subport.port()) + .collect_vec(); + self.hugr_mut().graph.unlink_port(from_port_index); + for linked_port_index in linked_ports { + let _ = self + .hugr_mut() + .graph + .link_ports(to_port_index, linked_port_index) + .expect("Ports exist"); + } + } + index..new_ports.len() + } + fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { let mut incoming = self.hugr_mut().graph.num_inputs(node.pg_index()); let mut outgoing = self.hugr_mut().graph.num_outputs(node.pg_index()); @@ -309,3 +370,37 @@ impl + AsMut> HugrMutInternals for T { Ok(std::mem::replace(cur, op.into())) } } + +#[cfg(test)] +mod test { + use crate::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::Noop, + hugr::internal::HugrMutInternals as _, + ops::handle::NodeHandle, + types::{Signature, Type}, + Direction, HugrView as _, + }; + + #[test] + fn insert_ports() { + let (nop, mut hugr) = { + let mut builder = DFGBuilder::new(Signature::new_endo(Type::UNIT)).unwrap(); + let [nop_in] = builder.input_wires_arr(); + let nop = builder + .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) + .unwrap(); + let [nop_out] = nop.outputs_arr(); + ( + nop.node(), + builder.finish_hugr_with_outputs([nop_out]).unwrap(), + ) + }; + let [i, o] = hugr.get_io(hugr.root()).unwrap(); + hugr.insert_ports(nop, Direction::Incoming, 0, 2); + hugr.insert_ports(nop, Direction::Outgoing, 0, 2); + + assert_eq!(hugr.single_linked_input(i, 0), Some((nop, 2.into()))); + assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 2.into()))); + } +} From f096dd727bbadb3f9c6e01e3317434437bd526dc Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 11:25:03 +0000 Subject: [PATCH 14/23] fixes --- hugr-core/src/hugr/internal.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 1f67ff873a..51cb7af584 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -293,7 +293,7 @@ impl + AsMut> HugrMutInternals for T { Direction::Outgoing => self.base_hugr().graph.num_outputs(node.pg_index()), }; - let new_ports = self.add_ports(node, direction, amount as isize); + self.add_ports(node, direction, amount as isize); for swap_from_port in (index..old_num_ports).rev() { let swap_to_port = swap_from_port + amount; @@ -318,7 +318,7 @@ impl + AsMut> HugrMutInternals for T { .expect("Ports exist"); } } - index..new_ports.len() + index..index + amount } fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { @@ -374,7 +374,7 @@ impl + AsMut> HugrMutInternals for T { #[cfg(test)] mod test { use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, + builder::{Container, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::Noop, hugr::internal::HugrMutInternals as _, ops::handle::NodeHandle, @@ -390,6 +390,7 @@ mod test { let nop = builder .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) .unwrap(); + builder.add_other_wire(nop.node(), builder.output().node()); let [nop_out] = nop.outputs_arr(); ( nop.node(), @@ -397,10 +398,11 @@ mod test { ) }; let [i, o] = hugr.get_io(hugr.root()).unwrap(); - hugr.insert_ports(nop, Direction::Incoming, 0, 2); - hugr.insert_ports(nop, Direction::Outgoing, 0, 2); + assert_eq!(0..2, hugr.insert_ports(nop, Direction::Incoming, 0, 2)); + assert_eq!(1..3, hugr.insert_ports(nop, Direction::Outgoing, 1, 2)); assert_eq!(hugr.single_linked_input(i, 0), Some((nop, 2.into()))); - assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 2.into()))); + assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 0.into()))); + assert_eq!(hugr.single_linked_output(o, 1), Some((nop, 3.into()))); } } From 504afe507eee5203431993156f926fccd0046aa8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 11:33:42 +0000 Subject: [PATCH 15/23] with_prelude --- hugr-core/src/hugr/internal.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 51cb7af584..33d791266d 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -385,7 +385,8 @@ mod test { #[test] fn insert_ports() { let (nop, mut hugr) = { - let mut builder = DFGBuilder::new(Signature::new_endo(Type::UNIT)).unwrap(); + let mut builder = + DFGBuilder::new(Signature::new_endo(Type::UNIT).with_prelude()).unwrap(); let [nop_in] = builder.input_wires_arr(); let nop = builder .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) From 6cc87a465edb5bba832b48300b2d514c790fa2c0 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:26:20 +0000 Subject: [PATCH 16/23] works --- hugr-core/src/hugr/hugrmut.rs | 45 +--- hugr-passes/src/non_local.rs | 412 ++++++++++++++++++---------------- 2 files changed, 214 insertions(+), 243 deletions(-) diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index b76d1897fe..4056f36e61 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -4,7 +4,6 @@ use core::panic; use std::collections::HashMap; use std::sync::Arc; -use itertools::Itertools as _; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap}; @@ -12,7 +11,7 @@ use crate::extension::ExtensionRegistry; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; -use crate::{Direction, Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; +use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; use super::NodeMetadataMap; @@ -279,48 +278,6 @@ pub trait HugrMut: HugrMutInternals { fn extensions_mut(&mut self) -> &mut ExtensionRegistry { &mut self.hugr_mut().extensions } - - /// TODO perhaps these should be on HugrMut? - fn insert_incoming_port(&mut self, node: Node, index: usize) -> IncomingPort { - let _ = self - .add_ports(node, Direction::Incoming, 1) - .exactly_one() - .unwrap(); - - for (to, from) in (index..self.num_inputs(node)) - .map_into::() - .rev() - .tuple_windows() - { - let linked_outputs = self.linked_outputs(node, from).collect_vec(); - self.disconnect(node, from); - for (linked_node, linked_port) in linked_outputs { - self.connect(linked_node, linked_port, node, to); - } - } - index.into() - } - - /// TODO perhaps these should be on HugrMut? - fn insert_outgoing_port(&mut self, node: Node, index: usize) -> OutgoingPort { - let _ = self - .add_ports(node, Direction::Outgoing, 1) - .exactly_one() - .unwrap(); - - for (to, from) in (index..self.num_outputs(node)) - .map_into::() - .rev() - .tuple_windows() - { - let linked_inputs = self.linked_inputs(node, from).collect_vec(); - self.disconnect(node, from); - for (linked_node, linked_port) in linked_inputs { - self.connect(node, to, linked_node, linked_port); - } - } - index.into() - } } /// Records the result of inserting a Hugr or view diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 8600f19c50..20344bc6a6 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -7,7 +7,7 @@ use std::{ }; //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions -use itertools::{Either, Itertools as _}; +use itertools::Itertools as _; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, @@ -18,7 +18,7 @@ use hugr_core::{ }, ops::{DataflowOpTrait as _, OpType, Tag, TailLoop}, types::{EdgeKind, Type, TypeRow}, - HugrView, IncomingPort, Node, PortIndex, Wire, + Direction, HugrView, IncomingPort, Node, PortIndex, Wire, }; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -95,24 +95,34 @@ struct WorkItem { } #[derive(Clone, Default, Debug)] -struct ParentSourceMap(HashMap>); +struct ParentSourceMap(HashMap>); impl ParentSourceMap { - // fn contains_parent(&self, parent: Node) -> bool { - // self.0.contains_key(&parent) - // } - fn insert_sources_in_parent( &mut self, parent: Node, - sources: impl IntoIterator, + sources: impl IntoIterator, ) { debug_assert!(!self.0.contains_key(&parent)); - self.0.entry(parent).or_default().extend(sources); + self.0 + .entry(parent) + .or_default() + .extend(sources.into_iter().map(|(s, p, t)| (s, (p, t)))); } - fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option { - self.0.get(&parent).and_then(|m| m.get(&source).cloned()) + fn get_source_in_parent( + &self, + parent: Node, + source: Wire, + ref hugr: impl HugrView, + ) -> (Wire, Type) { + let r @ (w, _) = self + .0 + .get(&parent) + .and_then(|m| m.get(&source).cloned()) + .unwrap(); + debug_assert_eq!(hugr.get_parent(w.node()).unwrap(), parent); + r } fn thread_dataflow_parent( @@ -121,26 +131,122 @@ impl ParentSourceMap { parent: Node, start_port_index: usize, sources: impl IntoIterator, - ) -> impl Iterator { - let [input_n, _] = hugr.get_io(parent).unwrap(); - let OpType::Input(mut input) = hugr.get_optype(input_n).clone() else { + ) { + let (source_wires, source_types): (Vec<_>, Vec<_>) = sources.into_iter().unzip(); + let input_wires = { + let [input_n, _] = hugr.get_io(parent).unwrap(); + let Some(mut input) = hugr.get_optype(input_n).as_input().cloned() else { + panic!("impossible") + }; + vec_insert(input.types.to_mut(), source_types.clone(), start_port_index); + hugr.replace_op(input_n, input).unwrap(); + hugr.insert_ports( + input_n, + Direction::Outgoing, + start_port_index, + source_wires.len(), + ) + .map(move |new_port| Wire::new(input_n, new_port)) + .collect_vec() + }; + self.insert_sources_in_parent( + parent, + itertools::izip!(source_wires, input_wires, source_types), + ); + } +} + +#[derive(Clone, Debug)] +struct ControlWorkItem { + output_node: Node, + variant_source_prefixes: Vec>, +} + +impl ControlWorkItem { + fn go(self, hugr: &mut impl HugrMut, psm: &ParentSourceMap) { + let parent = hugr.get_parent(self.output_node).unwrap(); + let Some(mut output) = hugr.get_optype(self.output_node).as_output().cloned() else { panic!("impossible") }; - let mut input_wires = vec![]; - self.0 - .entry(parent) - .or_default() - .extend(sources.into_iter().enumerate().map(|(i, (source, ty))| { - input.types.to_mut().insert(start_port_index + i, ty); - let input_wire = Wire::new( - input_n, - hugr.insert_outgoing_port(input_n, start_port_index + i), - ); - input_wires.push(input_wire); - (source, input_wire) - })); - hugr.replace_op(input_n, input).unwrap(); - input_wires.into_iter() + let mut needed_sources = BTreeMap::new(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = hugr + .get_optype(self.output_node) + .port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum() else { + panic!("impossible") + }; + + let mut type_for_source = |source: &Wire| { + let (w, t) = psm.get_source_in_parent(parent, *source, &hugr); + let replaced = needed_sources.insert(*source, (w, t.clone())); + debug_assert!(!replaced.is_some_and(|x| x != (w, t.clone()))); + t + }; + let old_sum_rows: Vec = sum_type + .variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows: Vec = + itertools::zip_eq(self.variant_source_prefixes.iter(), old_sum_rows.iter()) + .map(|(new_sources, old_tys)| { + new_sources + .iter() + .map(&mut type_for_source) + .chain(old_tys.iter().cloned()) + .collect_vec() + .into() + }) + .collect_vec(); + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = ConditionalBuilder::new( + old_sum_rows.clone(), + needed_sources + .values() + .map(|(_, t)| t.clone()) + .collect_vec(), + new_control_type.clone(), + ) + .unwrap(); + for (i, new_sources) in self.variant_source_prefixes.into_iter().enumerate() { + let mut case = cond.case_builder(i).unwrap(); + let case_inputs = case.input_wires().collect_vec(); + let mut args = new_sources + .into_iter() + .map(|s| { + case_inputs[old_sum_rows[i].len() + + needed_sources + .iter() + .find_position(|(&w, _)| w == s) + .unwrap() + .0] + }) + .collect_vec(); + args.extend(&case_inputs[..old_sum_rows[i].len()]); + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(parent, cond).new_root; + let (old_output_source_node, old_output_source_port) = + hugr.single_linked_output(self.output_node, 0).unwrap(); + debug_assert_eq!(hugr.get_parent(old_output_source_node).unwrap(), parent); + hugr.connect(old_output_source_node, old_output_source_port, cond_node, 0); + for (i, &(w, _)) in needed_sources.values().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } + hugr.disconnect(self.output_node, IncomingPort::from(0)); + hugr.connect(cond_node, 0, self.output_node, 0); + output.types.to_mut()[0] = new_control_type; + hugr.replace_op(self.output_node, output).unwrap(); } } @@ -149,21 +255,19 @@ struct ThreadState<'a> { parent_source_map: ParentSourceMap, needs: &'a BBNeedsSourcesMap, worklist: Vec, + control_worklist: Vec, } impl<'a> ThreadState<'a> { delegate! { to self.parent_source_map { - // fn contains_parent(&self, parent: Node) -> bool; - // fn get_source_in_parent(&self, parent: Node, source: Wire) -> Option; - fn insert_sources_in_parent(&mut self, parent: Node, sources: impl IntoIterator); fn thread_dataflow_parent( &mut self, hugr: &mut impl HugrMut, parent: Node, start_port_index: usize, sources: impl IntoIterator, - ) -> impl Iterator; + ); } } @@ -172,6 +276,7 @@ impl<'a> ThreadState<'a> { parent_source_map: Default::default(), needs: bbnsm, worklist: vec![], + control_worklist: vec![], } } @@ -189,14 +294,20 @@ impl<'a> ThreadState<'a> { for (this_p, succ_n) in hugr.node_outputs(node).filter_map(|out_p| { let (succ_n, _) = hugr.single_linked_input(node, out_p).unwrap(); - hugr.get_optype(succ_n).is_dataflow_block().then_some((out_p.index(), succ_n)) + hugr.get_optype(succ_n) + .is_dataflow_block() + .then_some((out_p.index(), succ_n)) }) { let succ_needs = &self.needs[&succ_n]; let succ_needs_source_indices = succ_needs .iter() .map(|(&w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) .collect_vec(); - let succ_needs_tys = succ_needs_source_indices.iter().copied().map(|x| sources[x].1.clone()).collect_vec(); + let succ_needs_tys = succ_needs_source_indices + .iter() + .copied() + .map(|x| sources[x].1.clone()) + .collect_vec(); vec_prepend(this_dfb.sum_rows[this_p].to_mut(), succ_needs_tys); nsrp[this_p] = succ_needs_source_indices; } @@ -204,88 +315,28 @@ impl<'a> ThreadState<'a> { nsrp }; - let input_wires = self - .thread_dataflow_parent(hugr, node, 0, sources.clone()) - .collect_vec(); + self.thread_dataflow_parent(hugr, node, 0, sources.clone()); let [_, o] = hugr.get_io(node).unwrap(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = - hugr.get_optype(o).port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum() else { - panic!("impossible") - }; - - let old_sum_rows: Vec = sum_type - .variants() - .map(|x| x.clone().try_into().unwrap()) - .collect_vec(); - let new_sum_rows: Vec = - itertools::zip_eq(new_sum_row_prefixes.iter(), old_sum_rows.iter()) - .map(|(new_source_indices, old_tys)| { - new_source_indices.into_iter() - .map(|&x| sources[x].1.clone()) - .chain(old_tys.iter().cloned()) - .collect_vec() - .into() - }) - .collect_vec(); - - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = ConditionalBuilder::new( - old_sum_rows.clone(), - types.clone(), - new_control_type.clone(), - ) - .unwrap(); - for (i, new_source_indices) in new_sum_row_prefixes.into_iter().enumerate() { - let mut case = cond.case_builder(i).unwrap(); - let case_inputs = case.input_wires().collect_vec(); - let mut args = vec![]; - for source_i in new_source_indices { - args.push(case_inputs[old_sum_rows[i].len() + source_i]); - } - - args.extend(&case_inputs[..old_sum_rows[i].len()]); - - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); - } - (cond.finish_hugr().unwrap(), new_control_type) - }; - let cond_node = hugr.insert_hugr(node, cond).new_root; - let (n, p) = hugr.single_linked_output(o, 0).unwrap(); - hugr.connect(n, p, cond_node, 0); - for (i, w) in input_wires.into_iter().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); - } - hugr.disconnect(o, IncomingPort::from(0)); - hugr.connect(cond_node, 0, o, 0); - let mut output = hugr.get_optype(o).as_output().unwrap().clone(); - output.types.to_mut()[0] = new_control_type; - hugr.replace_op(o, output).unwrap(); - dbg!(hugr.single_linked_output(o, 0)); + self.control_worklist.push(ControlWorkItem { + output_node: o, + variant_source_prefixes: new_sum_row_prefixes + .into_iter() + .map(|v| v.into_iter().map(|i| sources[i].0.clone()).collect_vec()) + .collect_vec(), + }); } fn do_cfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { let types = sources.iter().map(|x| x.1.clone()).collect_vec(); { let mut cfg = hugr.get_optype(node).as_cfg().unwrap().clone(); - let inputs = cfg.signature.input.to_mut(); - let old_inputs = mem::replace(inputs, types); - inputs.extend(old_inputs); + vec_insert(cfg.signature.input.to_mut(), types, 0); hugr.replace_op(node, cfg).unwrap(); } - let new_cond_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, i)) - .collect_vec(); - self.insert_sources_in_parent(node, iter::empty()); + let new_cond_ports = hugr + .insert_ports(node, Direction::Incoming, 0, sources.len()) + .map_into(); self.worklist .extend(mk_workitems(node, sources, new_cond_ports)) } @@ -293,16 +344,20 @@ impl<'a> ThreadState<'a> { fn do_dfg(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { let mut dfg = hugr.get_optype(node).as_dfg().unwrap().clone(); let start_new_port_index = dfg.signature.input().len(); - let new_dfg_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) - .collect_vec(); + let new_dfg_ports = hugr + .insert_ports( + node, + Direction::Incoming, + start_new_port_index, + sources.len(), + ) + .map_into(); dfg.signature .input .to_mut() .extend(sources.iter().map(|x| x.1.clone())); hugr.replace_op(node, dfg).unwrap(); - let _ = - self.thread_dataflow_parent(hugr, node, start_new_port_index, sources.iter().cloned()); + self.thread_dataflow_parent(hugr, node, start_new_port_index, sources.iter().cloned()); self.worklist .extend(mk_workitems(node, sources, new_dfg_ports)); } @@ -314,10 +369,14 @@ impl<'a> ThreadState<'a> { .to_mut() .extend(sources.iter().map(|x| x.1.clone())); hugr.replace_op(node, cond).unwrap(); - let new_cond_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_new_port_index + i)) - .collect_vec(); - self.insert_sources_in_parent(node, iter::empty()); + let new_cond_ports = hugr + .insert_ports( + node, + Direction::Incoming, + start_new_port_index, + sources.len(), + ) + .map_into(); self.worklist .extend(mk_workitems(node, sources, new_cond_ports)) } @@ -330,95 +389,48 @@ impl<'a> ThreadState<'a> { .to_mut() .extend(sources.iter().map(|x| x.1.clone())); hugr.replace_op(node, case).unwrap(); - let _ = self.thread_dataflow_parent(hugr, node, start_case_port_index, sources); + self.thread_dataflow_parent(hugr, node, start_case_port_index, sources); } fn do_tailloop(&mut self, hugr: &mut impl HugrMut, node: Node, sources: Vec<(Wire, Type)>) { let mut tailloop = hugr.get_optype(node).as_tail_loop().unwrap().clone(); let types = sources.iter().map(|x| x.1.clone()).collect_vec(); - let start_port_index = tailloop.just_inputs.len(); { - tailloop.just_inputs.to_mut().extend(types.clone()); + vec_prepend(tailloop.just_inputs.to_mut(), types.clone()); hugr.replace_op(node, tailloop).unwrap(); } - let tailloop_ports = (0..sources.len()) - .map(|i| hugr.insert_incoming_port(node, start_port_index + i)) - .collect_vec(); + let tailloop_ports = hugr + .insert_ports(node, Direction::Incoming, 0, sources.len()) + .map_into(); - let input_wires = self - .thread_dataflow_parent(hugr, node, start_port_index, sources.clone()) - .collect_vec(); + self.thread_dataflow_parent(hugr, node, 0, sources.clone()); let [_, o] = hugr.get_io(node).unwrap(); - let (cond, new_control_type) = { - let Some(EdgeKind::Value(control_type)) = - hugr.get_optype(o).port_kind(IncomingPort::from(0)) - else { - panic!("impossible") - }; - let Some(sum_type) = control_type.as_sum() else { - panic!("impossible") - }; - - let old_sum_rows: Vec = sum_type - .variants() - .map(|x| x.clone().try_into().unwrap()) - .collect_vec(); - let new_sum_rows = { - let mut v = old_sum_rows.clone(); - v[TailLoop::CONTINUE_TAG] - .to_mut() - .extend(types.iter().cloned()); - v - }; - - let new_control_type = Type::new_sum(new_sum_rows.clone()); - let mut cond = - ConditionalBuilder::new(old_sum_rows, types.clone(), new_control_type.clone()) - .unwrap(); - for i in 0..2 { - let mut case = cond.case_builder(i).unwrap(); - let inputs = { - let all_inputs = case.input_wires(); - if i == TailLoop::CONTINUE_TAG { - Either::Left(all_inputs) - } else { - Either::Right(all_inputs.into_iter().dropping_back(types.len())) - } - }; - - let case_outputs = case - .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), inputs) - .unwrap() - .outputs(); - case.finish_with_outputs(case_outputs).unwrap(); - } - (cond.finish_hugr().unwrap(), new_control_type) + let new_sum_row_prefixes = { + let mut v = vec![vec![]; 2]; + v[TailLoop::CONTINUE_TAG].extend(sources.iter().map(|x| x.0)); + v }; - let cond_node = hugr.insert_hugr(node, cond).new_root; - let (n, p) = hugr.single_linked_output(o, 0).unwrap(); - hugr.connect(n, p, cond_node, 0); - for (i, w) in input_wires.into_iter().enumerate() { - hugr.connect(w.node(), w.source(), cond_node, i + 1); - } - hugr.disconnect(o, IncomingPort::from(0)); - hugr.connect(cond_node, 0, o, 0); - let mut output = hugr.get_optype(o).as_output().unwrap().clone(); - output.types.to_mut()[0] = new_control_type; - hugr.replace_op(o, output).unwrap(); + self.control_worklist.push(ControlWorkItem { + output_node: o, + variant_source_prefixes: new_sum_row_prefixes, + }); self.worklist .extend(mk_workitems(node, sources, tailloop_ports)) } - fn finish(self, _hugr: &mut impl HugrMut) -> (Vec, ParentSourceMap) { - (self.worklist, self.parent_source_map) + fn finish( + self, + _hugr: &mut impl HugrMut, + ) -> (Vec, ParentSourceMap, Vec) { + (self.worklist, self.parent_source_map, self.control_worklist) } } fn thread_sources( hugr: &mut impl HugrMut, bb_needs_sources_map: &BBNeedsSourcesMap, -) -> (Vec, ParentSourceMap) { +) -> (Vec, ParentSourceMap, Vec) { let mut state = ThreadState::new(bb_needs_sources_map); for (&bb, sources) in bb_needs_sources_map { let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); @@ -516,11 +528,6 @@ impl BBNeedsSourcesMapBuilder { .children(cfg) .filter(|&child| hugr.get_optype(child).is_dataflow_block()) .collect_vec(); - - // let mut dfb_needs_map: HashMap<_, _> = dfbs - // .iter() - // .map(|&n| (n, self.0.get(&n).cloned().unwrap_or_default())) - // .collect(); loop { let mut any_change = false; for &dfb in dfbs.iter() { @@ -627,11 +634,11 @@ pub fn remove_nonlocal_edges( } } - let (parent_source_map, worklist) = { + let (parent_source_map, worklist, control_worklist) = { let mut worklist = nonlocal_edges_map.into_values().collect_vec(); - let (wl, psm) = thread_sources(hugr, &bb_needs_sources_map); + let (wl, psm, control_worklist) = thread_sources(hugr, &bb_needs_sources_map); worklist.extend(wl); - (psm, worklist) + (psm, worklist, control_worklist) }; for wi in worklist { @@ -640,28 +647,35 @@ pub fn remove_nonlocal_edges( wi.source } else { parent_source_map - .get_source_in_parent(parent, wi.source) - .unwrap() + .get_source_in_parent(parent, wi.source, &hugr) + .0 }; debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); hugr.disconnect(wi.target.0, wi.target.1); hugr.connect(source.node(), source.source(), wi.target.0, wi.target.1); } + for cwi in control_worklist { + cwi.go(hugr, &parent_source_map) + } + Ok(()) } fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { - let mut old_v = mem::replace(v, ts.into_iter().collect()); - v.extend(old_v.drain(..)); + vec_insert(v, ts, 0) +} + +fn vec_insert(v: &mut Vec, ts: impl IntoIterator, index: usize) { + let mut old_v_iter = mem::replace(v, vec![]).into_iter(); + v.extend(old_v_iter.by_ref().take(index).chain(ts)); + v.extend(old_v_iter); } #[cfg(test)] mod test { use hugr_core::{ - builder::{ - DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, - }, + builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, extension::prelude::{bool_t, either_type, Noop}, ops::{handle::NodeHandle, Tag, TailLoop, Value}, type_row, @@ -722,7 +736,7 @@ mod test { let [w0] = outer.input_wires_arr(); let [w1] = { let inner = outer - .dfg_builder(Signature::new(type_row![], bool_t()), []) + .dfg_builder(Signature::new_endo(bool_t()), [w0]) .unwrap(); inner.finish_with_outputs([w0]).unwrap().outputs_arr() }; From 461a5ab6178c8ffccccd484935b52a90133cd3e8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:27:53 +0000 Subject: [PATCH 17/23] wip --- hugr-passes/src/non_local.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 20344bc6a6..0293323a9e 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -322,7 +322,7 @@ impl<'a> ThreadState<'a> { output_node: o, variant_source_prefixes: new_sum_row_prefixes .into_iter() - .map(|v| v.into_iter().map(|i| sources[i].0.clone()).collect_vec()) + .map(|v| v.into_iter().map(|i| sources[i].0).collect_vec()) .collect_vec(), }); } @@ -667,7 +667,7 @@ fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { } fn vec_insert(v: &mut Vec, ts: impl IntoIterator, index: usize) { - let mut old_v_iter = mem::replace(v, vec![]).into_iter(); + let mut old_v_iter = mem::take(v).into_iter(); v.extend(old_v_iter.by_ref().take(index).chain(ts)); v.extend(old_v_iter); } From 3a07aa3f1d6c8e52bfe2ac03d3741f815220fb38 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:31:57 +0000 Subject: [PATCH 18/23] get get_optype_mut --- hugr-core/src/hugr/internal.rs | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index b797335c9a..54d9004cdc 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -266,13 +266,6 @@ pub trait HugrMutInternals: RootTagged { } self.hugr_mut().replace_op(node, op) } - - /// TODO docs - fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { - panic_invalid_node(self, node); - // TODO refuse if node == self.root() because tag might be violated - self.hugr_mut().get_optype_mut(node) - } } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -371,14 +364,10 @@ impl + AsMut> HugrMutInternals for T { fn replace_op(&mut self, node: Node, op: impl Into) -> Result { // We know RootHandle=Node here so no need to check Ok(std::mem::replace( - self.hugr_mut().get_optype_mut(node)?, + self.hugr_mut().op_types.get_mut(node.pg_index()), op.into(), )) } - - fn get_optype_mut(&mut self, node: Node) -> Result<&mut OpType, HugrError> { - Ok(self.hugr_mut().op_types.get_mut(node.pg_index())) - } } #[cfg(test)] From 9d251df207c93cdfb67e97ee062ad782d5ac8dc1 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:32:25 +0000 Subject: [PATCH 19/23] fix merge --- hugr-core/src/hugr/internal.rs | 37 ---------------------------------- 1 file changed, 37 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 54d9004cdc..3e98ac0f20 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -406,40 +406,3 @@ mod test { assert_eq!(hugr.single_linked_output(o, 1), Some((nop, 3.into()))); } } - -#[cfg(test)] -mod test { - use crate::{ - builder::{Container, DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::Noop, - hugr::internal::HugrMutInternals as _, - ops::handle::NodeHandle, - types::{Signature, Type}, - Direction, HugrView as _, - }; - - #[test] - fn insert_ports() { - let (nop, mut hugr) = { - let mut builder = - DFGBuilder::new(Signature::new_endo(Type::UNIT).with_prelude()).unwrap(); - let [nop_in] = builder.input_wires_arr(); - let nop = builder - .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) - .unwrap(); - builder.add_other_wire(nop.node(), builder.output().node()); - let [nop_out] = nop.outputs_arr(); - ( - nop.node(), - builder.finish_hugr_with_outputs([nop_out]).unwrap(), - ) - }; - let [i, o] = hugr.get_io(hugr.root()).unwrap(); - assert_eq!(0..2, hugr.insert_ports(nop, Direction::Incoming, 0, 2)); - assert_eq!(1..3, hugr.insert_ports(nop, Direction::Outgoing, 1, 2)); - - assert_eq!(hugr.single_linked_input(i, 0), Some((nop, 2.into()))); - assert_eq!(hugr.single_linked_output(o, 0), Some((nop, 0.into()))); - assert_eq!(hugr.single_linked_output(o, 1), Some((nop, 3.into()))); - } -} From 9a8a5e257e34abb9bc71c10f17d04ec42347bf9c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:33:51 +0000 Subject: [PATCH 20/23] tweak --- hugr-core/src/hugr/internal.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 3e98ac0f20..75b0aab1d2 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -363,10 +363,8 @@ impl + AsMut> HugrMutInternals for T { fn replace_op(&mut self, node: Node, op: impl Into) -> Result { // We know RootHandle=Node here so no need to check - Ok(std::mem::replace( - self.hugr_mut().op_types.get_mut(node.pg_index()), - op.into(), - )) + let cur = self.hugr_mut().op_types.get_mut(node.pg_index()); + Ok(std::mem::replace(cur, op.into())) } } From dd4caa0f0fa8ad7c85f5d8556fd8fe5408c8d66f Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:52:29 +0000 Subject: [PATCH 21/23] fmt --- hugr-passes/src/non_local.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 0293323a9e..cba63b5496 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -114,7 +114,7 @@ impl ParentSourceMap { &self, parent: Node, source: Wire, - ref hugr: impl HugrView, + hugr: impl HugrView, ) -> (Wire, Type) { let r @ (w, _) = self .0 @@ -470,7 +470,8 @@ impl BBNeedsSourcesMapBuilder { self.0.entry(bb).or_default().insert(source, ty); } - fn extend_parent_needs_for(&mut self, ref hugr: impl HugrView, child: Node) -> bool { + fn extend_parent_needs_for(&mut self, hugr: impl HugrView, child: Node) -> bool { + let hugr = &hugr; let parent = hugr.get_parent(child).unwrap(); let parent_needs = self .0 From 6dabc6bd5f1b3bd52b26ffe834d39638bb670da9 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 10 Feb 2025 14:53:45 +0000 Subject: [PATCH 22/23] with_prelude --- hugr-passes/src/non_local.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index cba63b5496..4a654ec217 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -881,7 +881,7 @@ mod test { Type::UNIT, ], vec![Type::UNIT, other_output_type.clone()], - )) + ).with_prelude()) .unwrap(); let [b, nl1, nl2, unit] = outer.input_wires_arr(); let [unit, out] = { From 322facfc4cb29d64802e8d59fcc459c9b47c3b81 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 12 Feb 2025 10:19:51 +0000 Subject: [PATCH 23/23] wip --- devenv.lock | 49 +++---- hugr-passes/src/non_local.rs | 253 ++++++++++++++++++----------------- 2 files changed, 159 insertions(+), 143 deletions(-) diff --git a/devenv.lock b/devenv.lock index d606d21055..99f06fb280 100644 --- a/devenv.lock +++ b/devenv.lock @@ -51,10 +51,31 @@ "type": "github" } }, + "git-hooks": { + "inputs": { + "flake-compat": "flake-compat", + "gitignore": "gitignore", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1737465171, + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "9364dc02281ce2d37a1f55b6e51f7c0f65a75f17", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, "gitignore": { "inputs": { "nixpkgs": [ - "pre-commit-hooks", + "git-hooks", "nixpkgs" ] }, @@ -101,34 +122,16 @@ "type": "github" } }, - "pre-commit-hooks": { - "inputs": { - "flake-compat": "flake-compat", - "gitignore": "gitignore", - "nixpkgs": [ - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1735882644, - "owner": "cachix", - "repo": "pre-commit-hooks.nix", - "rev": "a5a961387e75ae44cc20f0a57ae463da5e959656", - "type": "github" - }, - "original": { - "owner": "cachix", - "repo": "pre-commit-hooks.nix", - "type": "github" - } - }, "root": { "inputs": { "devenv": "devenv", "fenix": "fenix", + "git-hooks": "git-hooks", "nixpkgs": "nixpkgs", "nixpkgs-stable": "nixpkgs-stable", - "pre-commit-hooks": "pre-commit-hooks" + "pre-commit-hooks": [ + "git-hooks" + ] } }, "rust-analyzer-src": { diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 4a654ec217..ad62e8a99d 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -7,7 +7,7 @@ use std::{ }; //TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions -use itertools::Itertools as _; +use itertools::{Either, Itertools as _}; use hugr_core::{ builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, @@ -94,6 +94,22 @@ struct WorkItem { ty: Type, } +impl WorkItem { + pub fn go(self, hugr: &mut impl HugrMut, parent_source_map: &ParentSourceMap) { + let parent = hugr.get_parent(self.target.0).unwrap(); + let source = if hugr.get_parent(self.source.node()).unwrap() == parent { + self.source + } else { + parent_source_map + .get_source_in_parent(parent, self.source, &hugr) + .0 + }; + debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(self.target.0)); + hugr.disconnect(self.target.0, self.target.1); + hugr.connect(source.node(), source.source(), self.target.0, self.target.1); + } +} + #[derive(Clone, Default, Debug)] struct ParentSourceMap(HashMap>); @@ -298,10 +314,8 @@ impl<'a> ThreadState<'a> { .is_dataflow_block() .then_some((out_p.index(), succ_n)) }) { - let succ_needs = &self.needs[&succ_n]; - let succ_needs_source_indices = succ_needs - .iter() - .map(|(&w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) + let succ_needs_source_indices = self.needs.get(succ_n) + .map(|(w, _)| sources.iter().find_position(|(x, _)| x == &w).unwrap().0) .collect_vec(); let succ_needs_tys = succ_needs_source_indices .iter() @@ -432,7 +446,7 @@ fn thread_sources( bb_needs_sources_map: &BBNeedsSourcesMap, ) -> (Vec, ParentSourceMap, Vec) { let mut state = ThreadState::new(bb_needs_sources_map); - for (&bb, sources) in bb_needs_sources_map { + for (bb, sources) in bb_needs_sources_map { let sources = sources.iter().map(|(&w, ty)| (w, ty.clone())).collect_vec(); match hugr.get_optype(bb).clone() { OpType::DFG(_) => state.do_dfg(hugr, bb, sources), @@ -460,82 +474,105 @@ fn mk_workitems( }) } -type BBNeedsSourcesMap = HashMap>; - #[derive(Debug, Default, Clone)] -struct BBNeedsSourcesMapBuilder(BBNeedsSourcesMap); +struct BBNeedsSourcesMap(BTreeMap>); -impl BBNeedsSourcesMapBuilder { - fn insert(&mut self, bb: Node, source: Wire, ty: Type) { - self.0.entry(bb).or_default().insert(source, ty); +struct NeedsSourcesMapIter<'a>(<&'a BTreeMap> as IntoIterator>::IntoIter); + +impl<'a> Iterator for NeedsSourcesMapIter<'a> { + type Item = (Node, &'a BTreeMap); + + fn next(&mut self) -> Option { + self.0.next().map(|(&n,bt)| (n,bt)) } +} - fn extend_parent_needs_for(&mut self, hugr: impl HugrView, child: Node) -> bool { - let hugr = &hugr; - let parent = hugr.get_parent(child).unwrap(); - let parent_needs = self - .0 - .get(&child) - .into_iter() - .flat_map(move |m| { - m.iter() - .filter(move |(w, _)| hugr.get_parent(w.node()).unwrap() != parent) - .map(|(&w, ty)| (w, ty.clone())) - }) - .collect_vec(); - let any = !parent_needs.is_empty(); - if any { - self.0.entry(parent).or_default().extend(parent_needs); +impl<'a> IntoIterator for &'a BBNeedsSourcesMap { + type Item = as Iterator>::Item; + + type IntoIter = NeedsSourcesMapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + NeedsSourcesMapIter((&self.0).into_iter()) + } +} + +impl BBNeedsSourcesMap { + fn insert(&mut self, node: Node, source: Wire, ty: Type) -> bool { + self.0.entry(node).or_default().insert(source, ty).is_none() + } + + fn get(&self, node: Node) -> impl Iterator + '_ { + match self.0.get(&node) { + Some(x) => Either::Left(x.iter().map(|(&w, t)| (w, t))), + None => Either::Right(iter::empty()) + } + } + + delegate! { + to self.0 { + fn keys(&self) -> impl Iterator; + } + } +} + +#[derive(Debug, Clone)] +struct BBNeedsSourcesMapBuilder { + hugr: H, + needs_sources: BBNeedsSourcesMap, +} + +impl BBNeedsSourcesMapBuilder { + fn new(hugr: H) -> Self { + Self { + hugr, + needs_sources:Default::default(), } - any } - fn finish(mut self, hugr: impl HugrView) -> BBNeedsSourcesMap { + fn insert(&mut self, mut parent: Node, source: Wire, ty: Type) { + let source_parent = self.hugr.get_parent(source.node()).unwrap(); + loop { + if source_parent == parent { + break; + } + if !self.needs_sources.insert(parent, source, ty.clone()) { + break; + } + let Some(parent_of_parent) = self.hugr.get_parent(parent) else { + break; + }; + parent = parent_of_parent + } + } + + fn finish(mut self) -> BBNeedsSourcesMap { { - let conds = self - .0 - .keys() - .copied() - .filter(|&n| hugr.get_optype(n).is_conditional()) - .collect_vec(); - for cond in conds { - let cases = hugr - .children(cond) - .filter(|&child| hugr.get_optype(child).is_case()) - .collect_vec(); - let all_needed: BTreeMap<_, _> = cases - .iter() - .flat_map(|&case| { - let case_needed = self.0.get(&case); - case_needed - .into_iter() - .flat_map(|m| m.iter().map(|(&w, ty)| (w, ty.clone()))) - }) - .collect(); - for case in cases { - let _ = self.0.insert(case, all_needed.clone()); + let conds = self.needs_sources.keys().copied().filter(|&n| self.hugr.get_optype(n).is_conditional()).collect_vec(); + for n in conds { + let n_needs = self.needs_sources.get(n).map(|(w,ty)| (w, ty.clone())).collect_vec(); + for case in self.hugr + .children(n) + .filter(|&child| self.hugr.get_optype(child).is_case()) { + for (w, ty) in n_needs.iter() { + self.needs_sources.insert(case, *w, ty.clone()); + } } } } { - let cfgs = self - .0 - .keys() - .copied() - .filter(|&n| hugr.get_optype(n).is_cfg() && self.0.contains_key(&n)) - .collect_vec(); - for cfg in cfgs { - let dfbs = hugr - .children(cfg) - .filter(|&child| hugr.get_optype(child).is_dataflow_block()) + let cfgs = self.needs_sources.keys().copied().filter(|&n| self.hugr.get_optype(n).is_cfg()).collect_vec(); + for n in cfgs { + let dfbs = self.hugr + .children(n) + .filter(|&child| self.hugr.get_optype(child).is_dataflow_block()) .collect_vec(); loop { let mut any_change = false; for &dfb in dfbs.iter() { - for succ_n in hugr.output_neighbours(dfb) { - for (w, ty) in self.0.get(&succ_n).cloned().unwrap_or_default() { - any_change |= - self.0.entry(dfb).or_default().insert(w, ty).is_none(); + for succ_n in self.hugr.output_neighbours(dfb) { + for (w, ty) in self.needs_sources.get(succ_n).map(|(w,ty)| (w, ty.clone())).collect_vec() { + any_change |= self.needs_sources.insert(dfb, w, ty.clone()); } } } @@ -545,20 +582,34 @@ impl BBNeedsSourcesMapBuilder { } } } + self.needs_sources + } +} - self.0 +fn build_needs_sources_map(hugr: impl HugrView, nonlocal_edges: &HashMap) -> BBNeedsSourcesMap { + let mut bnsm = BBNeedsSourcesMapBuilder::new(&hugr); + for workitem in nonlocal_edges.values() { + let parent = hugr.get_parent(workitem.target.0).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.insert(parent, workitem.source, workitem.ty.clone()); } + bnsm.finish() } pub fn remove_nonlocal_edges( hugr: &mut impl HugrMut, root: Node, ) -> Result<(), NonLocalEdgesError> { + // First we collect all the non-local edges in the graph. We associate them to a WorkItem, which tracks: + // * the source of the non-local edge + // * the target of the non-local edge + // * the type of the non-local edge. Note that all non-local edges are + // value edges, so the type is well defined. let nonlocal_edges_map: HashMap = nonlocal_edges(&DescendantsGraph::::try_new(hugr, root)?) - .map(|target @ (node, inport)| { + .filter_map(|target @ (node, inport)| { let source = { - let (n, p) = hugr.single_linked_output(node, inport).unwrap(); + let (n, p) = hugr.single_linked_output(node, inport)?; Wire::new(n, p) }; debug_assert!( @@ -569,7 +620,7 @@ pub fn remove_nonlocal_edges( else { panic!("impossible") }; - (node, WorkItem { source, target, ty }) + Some((node, WorkItem { source, target, ty })) }) .collect(); @@ -577,45 +628,12 @@ pub fn remove_nonlocal_edges( return Ok(()); } - let bb_needs_sources_map = { - let nonlocal_sorted = { - let mut v = iter::successors(Some(vec![root]), |nodes| { - let children = nodes.iter().flat_map(|&n| hugr.children(n)).collect_vec(); - (!children.is_empty()).then_some(children) - }) - .flatten() - .filter_map(|n| nonlocal_edges_map.get(&n)) - .collect_vec(); - v.reverse(); - v - }; - let mut parent_set = HashSet::::new(); - // earlier items are deeper in the heirarchy - let mut parent_worklist = VecDeque::::new(); - let mut add_parent = |p, wl: &mut VecDeque<_>| { - if parent_set.insert(p) { - wl.push_back(p); - } - }; - let mut bnsm = BBNeedsSourcesMapBuilder::default(); - for workitem in nonlocal_sorted { - let parent = hugr.get_parent(workitem.target.0).unwrap(); - debug_assert!(hugr.get_parent(parent).is_some()); - bnsm.insert(parent, workitem.source, workitem.ty.clone()); - add_parent(parent, &mut parent_worklist); - } - - while let Some(bb_node) = parent_worklist.pop_front() { - let Some(parent) = hugr.get_parent(bb_node) else { - continue; - }; - if bnsm.extend_parent_needs_for(&hugr, bb_node) { - add_parent(parent, &mut parent_worklist); - } - } - bnsm.finish(&hugr) - }; + // We now compute the sources needed by each parent node. + // For a given non-local edge every intermediate node in the hierarchy + // between the source's parent and the target needs that source. + let bb_needs_sources_map = build_needs_sources_map(&hugr, &nonlocal_edges_map); + // TODO move this out-of-line #[cfg(debug_assertions)] { for (&n, wi) in nonlocal_edges_map.iter() { @@ -625,7 +643,7 @@ pub fn remove_nonlocal_edges( if hugr.get_parent(wi.source.node()).unwrap() == parent { break; } - assert!(bb_needs_sources_map[&parent].contains_key(&wi.source)); + assert!(bb_needs_sources_map.get(parent).find(|(w,_)| *w == wi.source).is_some()); m = parent; } } @@ -635,6 +653,11 @@ pub fn remove_nonlocal_edges( } } + // Here we mutate the HUGR; adding ports to parent nodes and their Input nodes. + // The result is: + // * parent_source_map: A map from parent and source to the wire that should substitute for that source in that parent. + // * worklist: a list of workitems. Each should be fulfilled by connecting the source, substituted through parent_source_map, to the target. + // * control_worklist: A list of control ports (i.e. 0th output port of DataflowBlock or TailLoop) that must be rewired. let (parent_source_map, worklist, control_worklist) = { let mut worklist = nonlocal_edges_map.into_values().collect_vec(); let (wl, psm, control_worklist) = thread_sources(hugr, &bb_needs_sources_map); @@ -643,17 +666,7 @@ pub fn remove_nonlocal_edges( }; for wi in worklist { - let parent = hugr.get_parent(wi.target.0).unwrap(); - let source = if hugr.get_parent(wi.source.node()).unwrap() == parent { - wi.source - } else { - parent_source_map - .get_source_in_parent(parent, wi.source, &hugr) - .0 - }; - debug_assert_eq!(hugr.get_parent(source.node()), hugr.get_parent(wi.target.0)); - hugr.disconnect(wi.target.0, wi.target.1); - hugr.connect(source.node(), source.source(), wi.target.0, wi.target.1); + wi.go(hugr, &parent_source_map) } for cwi in control_worklist {