diff --git a/tket-qsystem/Cargo.toml b/tket-qsystem/Cargo.toml index 0bd978448..61159ef09 100644 --- a/tket-qsystem/Cargo.toml +++ b/tket-qsystem/Cargo.toml @@ -45,6 +45,9 @@ typetag.workspace = true delegate.workspace = true indexmap.workspace = true anyhow = { workspace = true, optional = true } +tket1-passes = { path = "../tket1-passes" } +rayon.workspace = true +serde_json.workspace = true [dev-dependencies] cool_asserts.workspace = true diff --git a/tket-qsystem/src/lib.rs b/tket-qsystem/src/lib.rs index bdcf889cd..6208f18ee 100644 --- a/tket-qsystem/src/lib.rs +++ b/tket-qsystem/src/lib.rs @@ -24,8 +24,13 @@ use hugr_passes::{ use std::collections::HashSet; use lower_drops::LowerDropsPass; +use pytket::qsystem_decoder_config; +use rayon::iter::ParallelIterator; use replace_bools::{ReplaceBoolPass, ReplaceBoolPassError}; +use std::sync::Arc; use tket::TketOp; +use tket::serialize::pytket::{EncodeOptions, EncodedCircuit}; +use tket1_passes::{Tket1Circuit, Tket1Pass}; use extension::{ futures::FutureOpDef, @@ -152,6 +157,34 @@ impl QSystemPass { } // restore the entrypoint hugr.set_entrypoint(entrypoint); + + // Squash single qubit gates after conversion to the Qsystem gate set. + // Call the SquashRzPhasedX pass from pytket using the pass JSON + // https://docs.quantinuum.com/tket/api-docs/passes.html#pytket.passes.SquashRzPhasedX + let squash_pass_json_string = + serde_json::to_string(&tket_json_rs::pass::BasePass::StandardPass { + pass: tket_json_rs::pass::standard::StandardPass::SquashRzPhasedX, + }) + .unwrap(); + let mut encoded = + EncodedCircuit::new(hugr, EncodeOptions::new().with_subcircuits(true)).unwrap(); + encoded + .par_iter_mut() + .for_each(|(_region, serial_circuit)| { + let mut circuit_ptr = Tket1Circuit::from_serial_circuit(serial_circuit).unwrap(); + let my_circuit_json_before = serde_json::to_value(&serial_circuit).unwrap(); + println!("Circuit before ============================={_region}"); + println!("{}", my_circuit_json_before); + Tket1Pass::run_from_json(&squash_pass_json_string, &mut circuit_ptr).unwrap(); + *serial_circuit = circuit_ptr.to_serial_circuit().unwrap(); + + let my_circuit_json_after = serde_json::to_value(&serial_circuit).unwrap(); + println!("Circuit after =============================={_region}"); + println!("{}", my_circuit_json_after) + }); + encoded + .reassemble_inplace(hugr, Some(Arc::new(qsystem_decoder_config()))) + .unwrap(); Ok(()) } @@ -298,7 +331,7 @@ mod test { .finish_with_outputs([]) .unwrap(); - let (mut hugr, [call_node, h_node, f_node, rx_node, main_node]) = { + let (mut hugr, [call_node, h_node, f_node, rz_node, main_node]) = { let mut builder = mb .define_function( "main", @@ -327,7 +360,7 @@ mod test { .add_dataflow_op(QSystemOp::Rz, [qb, angle]) .unwrap() .outputs_arr(); - let rx_node = qb.node(); + let rz_node = qb.node(); // the Measure node will be removed. A Lazy Measure and two Future // Reads will be added. The Lazy Measure will be lifted and the @@ -342,7 +375,7 @@ mod test { .unwrap() .node(); let hugr = mb.finish_hugr().unwrap(); - (hugr, [call_node, h_node, f_node, rx_node, main_n]) + (hugr, [call_node, h_node, f_node, rz_node, main_n]) }; if set_entrypoint { // set the entrypoint to the main function @@ -362,7 +395,7 @@ mod test { }; assert!(get_pos(h_node) < get_pos(f_node)); assert!(get_pos(h_node) < get_pos(call_node)); - assert!(get_pos(rx_node) < get_pos(call_node)); + assert!(get_pos(rz_node) < get_pos(call_node)); for n in topo_sorted .iter()