diff --git a/tket-qsystem/tests/guppy_opt.rs b/tket-qsystem/tests/guppy_opt.rs index d8fbe9b57..46ca29450 100644 --- a/tket-qsystem/tests/guppy_opt.rs +++ b/tket-qsystem/tests/guppy_opt.rs @@ -1,8 +1,9 @@ //! Tests optimizing Guppy-generated programs. +use hugr::hugr::hugrmut::HugrMut; use rayon::iter::ParallelIterator; use smol_str::SmolStr; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::fs; use std::io::BufReader; use std::path::Path; @@ -46,19 +47,32 @@ fn load_guppy_example(path: &str) -> std::io::Result { } fn run_pytket(h: &mut Hugr) { - let circ = Circuit::new(h); - let mut encoded = - EncodedCircuit::new(&circ, EncodeOptions::new().with_subcircuits(true)).unwrap(); - - encoded - .par_iter_mut() - .for_each(|(_region, serial_circuit)| { - let mut circuit_ptr = Tket1Circuit::from_serial_circuit(serial_circuit).unwrap(); - Tket1Pass::run_from_json(CLIFFORD_SIMP_STR, &mut circuit_ptr).unwrap(); - *serial_circuit = circuit_ptr.to_serial_circuit().unwrap(); - }); - - encoded.reassemble_inplace(circ.into_hugr(), None).unwrap(); + let old_ep = h.entrypoint(); + // We do not believe the above should remove any nodes that are valid Circuit parents, + // except DFGs (which are "transparent"); so include only DFGs not contained in some + // optimizable ancestor. + let mut queue = VecDeque::from([h.module_root()]); + while let Some(n) = queue.pop_front() { + h.set_entrypoint(n); + + if let Ok(circ) = Circuit::try_new(&mut *h) { + let mut encoded = + EncodedCircuit::new(&circ, EncodeOptions::new().with_subcircuits(true)).unwrap(); + + encoded + .par_iter_mut() + .for_each(|(_region, serial_circuit)| { + let mut circuit_ptr = + Tket1Circuit::from_serial_circuit(serial_circuit).unwrap(); + Tket1Pass::run_from_json(CLIFFORD_SIMP_STR, &mut circuit_ptr).unwrap(); + *serial_circuit = circuit_ptr.to_serial_circuit().unwrap(); + }); + + encoded.reassemble_inplace(circ.into_hugr(), None).unwrap(); + } + queue.extend(h.children(n)) + } + h.set_entrypoint(old_ep); } fn count_gates(h: &impl HugrView) -> HashMap { @@ -80,22 +94,14 @@ fn count_gates(h: &impl HugrView) -> HashMap { #[rstest] #[case::nested_array("nested_array", None)] -#[should_panic = "xfail"] -#[case::angles("angles", Some(vec![ - ("tket.quantum.Rz", 2), ("tket.quantum.MeasureFree", 1), ("tket.quantum.H", 2), ("tket.quantum.QAlloc", 1) -]))] +#[should_panic = "PytketDecodeError { inner: DuplicatedParameter"] +#[case::angles("angles", None)] #[should_panic = "xfail"] #[case::simple_cx("simple_cx", Some(vec![ ("tket.quantum.QAlloc", 2), ("tket.quantum.MeasureFree", 2), ]))] -#[should_panic = "xfail"] -#[case::nested("nested", Some(vec![ - ("tket.quantum.CZ", 6), ("tket.quantum.QAlloc", 3), ("tket.quantum.MeasureFree", 3), ("tket.quantum.H", 6) -]))] -#[should_panic = "xfail"] -#[case::ranges("ranges", Some(vec![ - ("tket.quantum.H", 8), ("tket.quantum.MeasureFree", 4), ("tket.quantum.QAlloc", 4), ("tket.quantum.CX", 6) -]))] +#[case::nested("nested", None)] +#[case::ranges("ranges", None)] #[should_panic = "xfail"] #[case::false_branch("false_branch", Some(vec![ ("TKET1.tk1op", 1), ("tket.quantum.H", 1), ("tket.quantum.QAlloc", 1), ("tket.quantum.MeasureFree", 1) @@ -127,23 +133,9 @@ fn optimize_flattened_guppy(#[case] name: &str, #[case] xfail: Option From for Circuit { /// Checks if the passed hugr is a valid circuit, /// and return [`CircuitError`] if not. -fn check_hugr(hugr: &H) -> Result<(), CircuitError> { +pub fn check_hugr(hugr: &H) -> Result<(), CircuitError> { let optype = hugr.entrypoint_optype(); match optype { // Dataflow nodes are always valid. diff --git a/tket/src/passes/guppy.rs b/tket/src/passes/guppy.rs index e41ba8d0a..b40ff0e46 100644 --- a/tket/src/passes/guppy.rs +++ b/tket/src/passes/guppy.rs @@ -7,7 +7,8 @@ use hugr::algorithms::untuple::{UntupleError, UntupleRecursive}; use hugr::algorithms::{ComposablePass, RemoveDeadFuncsError, RemoveDeadFuncsPass, UntuplePass}; use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::patch::inline_dfg::InlineDFGError; -use hugr::Node; +use hugr::ops::Value; +use hugr::{IncomingPort, Node}; use crate::passes::BorrowSquashPass; @@ -80,6 +81,17 @@ impl + 'static> ComposablePass for NormalizeGuppy { type Error = NormalizeGuppyErrors; type Result = (); fn run(&self, hugr: &mut H) -> Result { + let old_ep = hugr.entrypoint(); + if self.dead_funcs && old_ep != hugr.module_root() { + // Remove everything not reachable from the entrypoint. + // (Could use visibility for module-entrypoint Hugrs, this might + // be appropriate if they are intended as libraries?) + RemoveDeadFuncsPass::default().run(hugr)?; + } + let old_ep = (old_ep != hugr.module_root()).then_some({ + hugr.set_entrypoint(hugr.module_root()); + old_ep + }); if self.simplify_cfgs { NormalizeCFGPass::default().run(hugr)?; } @@ -88,13 +100,12 @@ impl + 'static> ComposablePass for NormalizeGuppy { UntuplePass::new(UntupleRecursive::Recursive).run(hugr)?; } // Should propagate through untuple, so could do earlier, and must be before BorrowSquash - if self.constant_fold { - ConstantFoldPass::default().run(hugr)?; - } - // Only improves compilation speed, not affected by anything else - // until we start removing untaken branches - if self.dead_funcs { - RemoveDeadFuncsPass::default().run(hugr)?; + if let Some(ep) = old_ep.filter(|_| self.constant_fold) { + // For module-entrypoint Hugrs, we'd need to decide which functions are callable; + // the default is to assume none. + let no_inputs: [(IncomingPort, Value); 0] = []; + let cp = ConstantFoldPass::default().with_inputs(ep, no_inputs); + cp.run(hugr)?; } // Do earlier? Nothing creates DFGs if self.inline_dfgs { @@ -107,7 +118,9 @@ impl + 'static> ComposablePass for NormalizeGuppy { .run(hugr) .unwrap_or_else(|e| match e {}); } - + if let Some(ep) = old_ep { + hugr.set_entrypoint(ep); + } Ok(()) } }