Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 34 additions & 42 deletions tket-qsystem/tests/guppy_opt.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Tests optimizing Guppy-generated programs.

use hugr::hugr::hugrmut::HugrMut;
use rayon::iter::ParallelIterator;
use smol_str::SmolStr;
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::fs;
use std::io::BufReader;
use std::path::Path;
Expand Down Expand Up @@ -46,19 +47,32 @@ fn load_guppy_example(path: &str) -> std::io::Result<Hugr> {
}

fn run_pytket(h: &mut Hugr) {
let circ = Circuit::new(h);
let mut encoded =
EncodedCircuit::new(&circ, EncodeOptions::new().with_subcircuits(true)).unwrap();

encoded
.par_iter_mut()
.for_each(|(_region, serial_circuit)| {
let mut circuit_ptr = Tket1Circuit::from_serial_circuit(serial_circuit).unwrap();
Tket1Pass::run_from_json(CLIFFORD_SIMP_STR, &mut circuit_ptr).unwrap();
*serial_circuit = circuit_ptr.to_serial_circuit().unwrap();
});

encoded.reassemble_inplace(circ.into_hugr(), None).unwrap();
let old_ep = h.entrypoint();
// We do not believe the above should remove any nodes that are valid Circuit parents,
// except DFGs (which are "transparent"); so include only DFGs not contained in some
// optimizable ancestor.
let mut queue = VecDeque::from([h.module_root()]);
while let Some(n) = queue.pop_front() {
h.set_entrypoint(n);

if let Ok(circ) = Circuit::try_new(&mut *h) {
let mut encoded =
EncodedCircuit::new(&circ, EncodeOptions::new().with_subcircuits(true)).unwrap();

encoded
.par_iter_mut()
.for_each(|(_region, serial_circuit)| {
let mut circuit_ptr =
Tket1Circuit::from_serial_circuit(serial_circuit).unwrap();
Tket1Pass::run_from_json(CLIFFORD_SIMP_STR, &mut circuit_ptr).unwrap();
*serial_circuit = circuit_ptr.to_serial_circuit().unwrap();
});

encoded.reassemble_inplace(circ.into_hugr(), None).unwrap();
}
queue.extend(h.children(n))
}
h.set_entrypoint(old_ep);
}

fn count_gates(h: &impl HugrView) -> HashMap<SmolStr, usize> {
Expand All @@ -80,22 +94,14 @@ fn count_gates(h: &impl HugrView) -> HashMap<SmolStr, usize> {

#[rstest]
#[case::nested_array("nested_array", None)]
#[should_panic = "xfail"]
#[case::angles("angles", Some(vec![
("tket.quantum.Rz", 2), ("tket.quantum.MeasureFree", 1), ("tket.quantum.H", 2), ("tket.quantum.QAlloc", 1)
]))]
#[should_panic = "PytketDecodeError { inner: DuplicatedParameter"]
#[case::angles("angles", None)]
#[should_panic = "xfail"]
#[case::simple_cx("simple_cx", Some(vec![
("tket.quantum.QAlloc", 2), ("tket.quantum.MeasureFree", 2),
]))]
#[should_panic = "xfail"]
#[case::nested("nested", Some(vec![
("tket.quantum.CZ", 6), ("tket.quantum.QAlloc", 3), ("tket.quantum.MeasureFree", 3), ("tket.quantum.H", 6)
]))]
#[should_panic = "xfail"]
#[case::ranges("ranges", Some(vec![
("tket.quantum.H", 8), ("tket.quantum.MeasureFree", 4), ("tket.quantum.QAlloc", 4), ("tket.quantum.CX", 6)
]))]
#[case::nested("nested", None)]
#[case::ranges("ranges", None)]
#[should_panic = "xfail"]
#[case::false_branch("false_branch", Some(vec![
("TKET1.tk1op", 1), ("tket.quantum.H", 1), ("tket.quantum.QAlloc", 1), ("tket.quantum.MeasureFree", 1)
Expand Down Expand Up @@ -127,23 +133,9 @@ fn optimize_flattened_guppy(#[case] name: &str, #[case] xfail: Option<Vec<(&str,
#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
fn optimize_guppy_ranges_array() {
// Demonstrates we can fully optimize the array operations in ranges
// (after control flow is flattened) if we play around with the entrypoint.
use hugr::algorithms::const_fold::ConstantFoldPass;
use hugr::hugr::hugrmut::HugrMut;
use tket::passes::BorrowSquashPass;
// (starting with a Hugr where only control flow has been flattened, not arrays)
let mut hugr = load_guppy_example("ranges/ranges.flat.array.hugr").unwrap();

let f = hugr
.children(hugr.module_root())
.find(|n| {
hugr.get_optype(*n)
.as_func_defn()
.is_some_and(|fd| fd.func_name() == "f")
})
.unwrap();
hugr.set_entrypoint(f);
ConstantFoldPass::default().run(&mut hugr).unwrap();
BorrowSquashPass::default().run(&mut hugr).unwrap();
NormalizeGuppy::default().run(&mut hugr).unwrap();
run_pytket(&mut hugr);
let expected_counts =
count_gates(&load_guppy_circuit("ranges", HugrFileType::Optimized).unwrap());
Expand Down
2 changes: 1 addition & 1 deletion tket/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ impl<T: HugrView> From<T> for Circuit<T> {

/// Checks if the passed hugr is a valid circuit,
/// and return [`CircuitError`] if not.
fn check_hugr<H: HugrView>(hugr: &H) -> Result<(), CircuitError<H::Node>> {
pub fn check_hugr<H: HugrView>(hugr: &H) -> Result<(), CircuitError<H::Node>> {
let optype = hugr.entrypoint_optype();
match optype {
// Dataflow nodes are always valid.
Expand Down
31 changes: 22 additions & 9 deletions tket/src/passes/guppy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use hugr::algorithms::untuple::{UntupleError, UntupleRecursive};
use hugr::algorithms::{ComposablePass, RemoveDeadFuncsError, RemoveDeadFuncsPass, UntuplePass};
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::patch::inline_dfg::InlineDFGError;
use hugr::Node;
use hugr::ops::Value;
use hugr::{IncomingPort, Node};

use crate::passes::BorrowSquashPass;

Expand Down Expand Up @@ -80,6 +81,17 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for NormalizeGuppy {
type Error = NormalizeGuppyErrors;
type Result = ();
fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
let old_ep = hugr.entrypoint();
if self.dead_funcs && old_ep != hugr.module_root() {
// Remove everything not reachable from the entrypoint.
// (Could use visibility for module-entrypoint Hugrs, this might
// be appropriate if they are intended as libraries?)
RemoveDeadFuncsPass::default().run(hugr)?;
}
let old_ep = (old_ep != hugr.module_root()).then_some({
hugr.set_entrypoint(hugr.module_root());
old_ep
});
if self.simplify_cfgs {
NormalizeCFGPass::default().run(hugr)?;
}
Expand All @@ -88,13 +100,12 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for NormalizeGuppy {
UntuplePass::new(UntupleRecursive::Recursive).run(hugr)?;
}
// Should propagate through untuple, so could do earlier, and must be before BorrowSquash
if self.constant_fold {
ConstantFoldPass::default().run(hugr)?;
}
// Only improves compilation speed, not affected by anything else
// until we start removing untaken branches
if self.dead_funcs {
RemoveDeadFuncsPass::default().run(hugr)?;
if let Some(ep) = old_ep.filter(|_| self.constant_fold) {
// For module-entrypoint Hugrs, we'd need to decide which functions are callable;
// the default is to assume none.
let no_inputs: [(IncomingPort, Value); 0] = [];
let cp = ConstantFoldPass::default().with_inputs(ep, no_inputs);
cp.run(hugr)?;
}
// Do earlier? Nothing creates DFGs
if self.inline_dfgs {
Expand All @@ -107,7 +118,9 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for NormalizeGuppy {
.run(hugr)
.unwrap_or_else(|e| match e {});
}

if let Some(ep) = old_ep {
hugr.set_entrypoint(ep);
}
Ok(())
}
}
Expand Down
Loading