diff --git a/Cargo.lock b/Cargo.lock index 4020b94ef..e7852813f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "arrayvec" version = "0.5.2" @@ -361,7 +370,7 @@ version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a98d30140e3296250832bbaaff83b27dcd6fa3cc70fb6f1f3e5c9c0023b5317" dependencies = [ - "approx", + "approx 0.4.0", "num-traits", "serde", ] @@ -718,6 +727,86 @@ dependencies = [ "rayon", ] +[[package]] +name = "dashu" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b3e5ac1e23ff1995ef05b912e2b012a8784506987a2651552db2c73fb3d7e0" +dependencies = [ + "dashu-base", + "dashu-float", + "dashu-int", + "dashu-macros", + "dashu-ratio", + "rustversion", +] + +[[package]] +name = "dashu-base" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b80bf6b85aa68c58ffea2ddb040109943049ce3fbdf4385d0380aef08ef289" + +[[package]] +name = "dashu-float" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85078445a8dbd2e1bd21f04a816f352db8d333643f0c9b78ca7c3d1df71063e7" +dependencies = [ + "dashu-base", + "dashu-int", + "num-modular", + "num-order", + "num-traits", + "rustversion", + "static_assertions", +] + +[[package]] +name = "dashu-int" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee99d08031ca34a4d044efbbb21dff9b8c54bb9d8c82a189187c0651ffdb9fbf" +dependencies = [ + "cfg-if", + "dashu-base", + "num-modular", + "num-order", + "num-traits", + "rustversion", + "static_assertions", +] + +[[package]] +name = "dashu-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93381c3ef6366766f6e9ed9cf09e4ef9dec69499baf04f0c60e70d653cf0ab10" +dependencies = [ + "dashu-base", + "dashu-float", + "dashu-int", + "dashu-ratio", + "paste", + "proc-macro2", + "quote", + "rustversion", +] + +[[package]] +name = "dashu-ratio" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e33b04dd7ce1ccf8a02a69d3419e354f2bbfdf4eb911a0b7465487248764c9" +dependencies = [ + "dashu-base", + "dashu-float", + "dashu-int", + "num-modular", + "num-order", + "rustversion", +] + [[package]] name = "delegate" version = "0.10.0" @@ -866,6 +955,29 @@ dependencies = [ "syn 2.0.111", ] +[[package]] +name = "env_filter" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf3c259d255ca70051b30e2e95b5446cdb8949ac4cd22c0d7fd634d89f568e2" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1387,6 +1499,30 @@ version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ee5b5339afb4c41626dde77b7a611bd4f2c202b897852b4bcf5d03eddc61010" +[[package]] +name = "jiff" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67e8da4c49d6d9909fe03361f9b620f58898859f5c7aded68351e85e71ecf50" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c84ee7f197eca9a86c6fd6cb771e55eb991632f15f2bc3ca6ec838929e6e78" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -1463,6 +1599,16 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.6" @@ -1484,6 +1630,33 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "nalgebra" +version = "0.33.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b" +dependencies = [ + "approx 0.5.1", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "nom" version = "7.1.3" @@ -1503,6 +1676,20 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -1513,6 +1700,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -1528,6 +1724,32 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-modular" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17bb261bf36fa7d83f4c294f834e91256769097b3cb505d44831e0a179ac647f" + +[[package]] +name = "num-order" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "537b596b97c40fcf8056d153049eb22f481c17ebce72a513ec9286e4986d1bb6" +dependencies = [ + "num-modular", +] + [[package]] name = "num-rational" version = "0.4.2" @@ -1583,7 +1805,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", "serde", ] @@ -1756,6 +1978,15 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f59e70c4aef1e55797c2e8fd94a4f2a973fc972cfde0e0b05f683667b0cd39dd" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "portgraph" version = "0.8.0" @@ -1809,6 +2040,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "pretty" version = "0.12.5" @@ -1980,10 +2220,30 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "rand_core", + "rand_core 0.6.4", "serde", ] +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", +] + [[package]] name = "rand_core" version = "0.6.4" @@ -1993,6 +2253,21 @@ dependencies = [ "serde", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" @@ -2114,6 +2389,26 @@ dependencies = [ "serde", ] +[[package]] +name = "rsgridsynth" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d7a6206ba19a3dbe199d9931a8630ce8b0f0c6cc393d81c700b8f2c8bd312fc" +dependencies = [ + "clap", + "dashu", + "dashu-base", + "dashu-float", + "dashu-int", + "env_logger", + "log", + "nalgebra", + "num", + "num-traits", + "once_cell", + "rand 0.9.2", +] + [[package]] name = "rstest" version = "0.26.1" @@ -2189,6 +2484,15 @@ version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62049b2877bf12821e8f9ad256ee38fdc31db7387ec2d3b3f403024de2034aea" +[[package]] +name = "safe_arch" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b02de82ddbe1b636e6170c21be622223aea188ef2e139be0a5b219ec215323" +dependencies = [ + "bytemuck", +] + [[package]] name = "same-file" version = "1.0.6" @@ -2391,6 +2695,19 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "simba" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c99284beb21666094ba2b75bbceda012e610f5479dfcc2d6e2426f53197ffd95" +dependencies = [ + "approx 0.5.1", + "num-complex", + "num-traits", + "paste", + "wide", +] + [[package]] name = "similar" version = "2.7.0" @@ -2685,6 +3002,7 @@ dependencies = [ "priority-queue", "rayon", "rmp-serde", + "rsgridsynth", "rstest", "serde", "serde_json", @@ -3089,6 +3407,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "wide" +version = "0.7.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce5da8ecb62bcd8ec8b7ea19f69a51275e91299be594ea5cc6ef7819e16cd03" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 4538004cf..821675f46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,7 +104,4 @@ cool_asserts = "2.0.4" zstd = "0.13.3" anyhow = "1.0.100" num-rational = "0.4.2" - -[profile.release.package.tket-py] -# Some configurations to reduce the size of tket wheels -strip = true +rsgridsynth = "0.2.0" diff --git a/tket-py/Cargo.toml b/tket-py/Cargo.toml index ed6303d6d..7d4e535c8 100644 --- a/tket-py/Cargo.toml +++ b/tket-py/Cargo.toml @@ -26,7 +26,6 @@ tket = { path = "../tket", version = "0.16.0", features = [ ] } tket-qsystem = { path = "../tket-qsystem", version = "0.22.0" } tket1-passes = { path = "../tket1-passes", version = "0.0.0" } - derive_more = { workspace = true, features = ["into", "from"] } hugr = { workspace = true } itertools = { workspace = true } diff --git a/tket-py/src/passes.rs b/tket-py/src/passes.rs index 97920e762..1eba2238f 100644 --- a/tket-py/src/passes.rs +++ b/tket-py/src/passes.rs @@ -1,6 +1,7 @@ //! Passes for optimising circuits. pub mod chunks; +pub mod gridsynth; pub mod tket1; use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf}; @@ -29,6 +30,7 @@ pub fn module(py: Python<'_>) -> PyResult> { m.add_function(wrap_pyfunction!(normalize_guppy, &m)?)?; m.add_class::()?; m.add_function(wrap_pyfunction!(self::chunks::chunks, &m)?)?; + m.add_function(wrap_pyfunction!(self::gridsynth::gridsynth, &m)?)?; m.add_function(wrap_pyfunction!(self::tket1::tket1_pass, &m)?)?; m.add("PullForwardError", py.get_type::())?; m.add("TK1PassError", py.get_type::())?; diff --git a/tket-py/src/passes/gridsynth.rs b/tket-py/src/passes/gridsynth.rs new file mode 100644 index 000000000..d1de5bac0 --- /dev/null +++ b/tket-py/src/passes/gridsynth.rs @@ -0,0 +1,32 @@ +//! Bindings to allow users to access the gridsynth pass from Python. +//! The definitions here should be reflected in the +//! `tket-py/tket/_tket/passes.pyi` type stubs file +use crate::circuit::CircuitType; +use crate::circuit::try_with_circ; +use crate::utils::{ConvertPyErr, create_py_exception}; +use pyo3::prelude::*; +use tket::passes::gridsynth::apply_gridsynth_pass; + +create_py_exception!( + tket::passes::gridsynth::GridsynthError, + PyGridsynthError, + "Errors from the gridsynth pass." +); + +/// Binding to a python function called gridsynth that runs the rust function called +/// apply_gridsynth pass behind the scenes +#[pyfunction] +pub fn gridsynth<'py>( + circ: &Bound<'py, PyAny>, + epsilon: f64, + simplify: bool, +) -> PyResult> { + let py = circ.py(); + + try_with_circ(circ, |mut circ: tket::Circuit, typ: CircuitType| { + apply_gridsynth_pass(circ.hugr_mut(), epsilon, simplify).convert_pyerrs()?; + + let circ = typ.convert(py, circ)?; + PyResult::Ok(circ) + }) +} diff --git a/tket-py/tket/_tket/passes.pyi b/tket-py/tket/_tket/passes.pyi index 285033945..448685ba9 100644 --- a/tket-py/tket/_tket/passes.pyi +++ b/tket-py/tket/_tket/passes.pyi @@ -105,3 +105,15 @@ def tket1_pass( circuit-like regions, and optimise them too. nested inside other subregions of the circuit. """ + +def gridsynth(hugr: CircuitClass, epsilon: float, simplify: bool) -> CircuitClass: + """Runs a pass applying the gridsynth algorithm to all Rz gates in a HUGR, + which decomposes them into the Clifford + T basis. + + Parameters: + - hugr: the hugr to run the pass on. + - epsilon: the precision of the gridsynth decomposition + - simplify: if `True`, each sequence of gridsynth gates is compressed into + a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence + always has a smaller number of S and H gates, and the same number of T+Tdg gates. + """ diff --git a/tket-py/tket/passes.py b/tket-py/tket/passes.py index 73b1beca9..2df63343c 100644 --- a/tket-py/tket/passes.py +++ b/tket-py/tket/passes.py @@ -31,6 +31,7 @@ tket1_pass, normalize_guppy, PullForwardError, + gridsynth, ) __all__ = [ @@ -176,3 +177,48 @@ def _normalize(self, hugr: Hugr, inplace: bool) -> PassResult: ) new_hugr = Hugr.from_str(opt_program.to_str()) return PassResult.for_pass(self, hugr=new_hugr, inplace=inplace, result=None) + + +@dataclass +class Gridsynth(ComposablePass): + epsilon: float + simplify: bool = True + + """Apply the gridsynth algorithm to all Rz gates in the Hugr. + + This includes a NormalizeGuppy pass with all of the constituent passes applied + before applying gridsynth to standardise the + format. + + Parameters: + - epsilon: The allowable error tolerance. + - simplify: If `True`, each sequence of gridsynth gates is compressed into + a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence + always has a smaller number of S and H gates, and the same number of T+Tdg gates. + """ + + # TO DO: make the NormalizeGuppy pass optional, in case it is already run + # before Gridsynth. Need to warn users, at least in docs that if NormalizeGuppy + # is not run first then Gridsynth is likely to fail. Maybe issue the warning if + # the option to run NormalizeGuppy is set to False. The option would be specified + # as a field of the dataclass (would also need to add @dataclass decorator) + # like for NormalizeGuppy above + def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult: + # inplace option does nothing for now but I retain for consistency of + # API with other passes + return implement_pass_run( + self, + hugr=hugr, + inplace=inplace, + copy_call=lambda h: self._apply_gridsynth_pass( + hugr, self.epsilon, self.simplify, inplace + ), + ) + + def _apply_gridsynth_pass( + self, hugr: Hugr, epsilon: float, simplify: bool, inplace: bool + ) -> PassResult: + compiler_state: Tk2Circuit = Tk2Circuit.from_bytes(hugr.to_bytes()) + program = gridsynth(compiler_state, epsilon, simplify) + new_hugr = Hugr.from_str(program.to_str()) + return PassResult.for_pass(self, hugr=new_hugr, inplace=inplace, result=None) diff --git a/tket/Cargo.toml b/tket/Cargo.toml index 98689d69b..88b67f5fb 100644 --- a/tket/Cargo.toml +++ b/tket/Cargo.toml @@ -77,6 +77,7 @@ pest_derive = { workspace = true } zstd = { workspace = true, optional = true } anyhow = { workspace = true, optional = true } num-rational = { workspace = true } +rsgridsynth = { workspace = true } [dev-dependencies] diff --git a/tket/src/passes.rs b/tket/src/passes.rs index 9594cbc93..9640af3a5 100644 --- a/tket/src/passes.rs +++ b/tket/src/passes.rs @@ -16,6 +16,8 @@ pub use guppy::NormalizeGuppy; pub mod pytket; pub use pytket::lower_to_pytket; +pub mod gridsynth; + pub mod tuple_unpack; #[expect(deprecated)] pub use tuple_unpack::find_tuple_unpack_rewrites; diff --git a/tket/src/passes/gridsynth.rs b/tket/src/passes/gridsynth.rs new file mode 100644 index 000000000..0e26d7a7c --- /dev/null +++ b/tket/src/passes/gridsynth.rs @@ -0,0 +1,519 @@ +//! A pass that applies the gridsynth algorithm to all Rz gates in a HUGR. +/// The pass introduced here assumes that (1) all functions have been inlined and +/// (2) that NormalizeGuppy has been run. It expects that the following is guaranteed: +/// * That every Const node is immediately connected to a LoadConst node. +/// * That every constant is used in a single place (i.e. that there's a single output +/// of each LoadConst). +/// * That we can find the Const node connected to an Rz node by first following the +/// input port 1 of the Rz node (i.e. the angle argument) and then iteratively +/// following the input on port 0 until we reach a Const node. +use std::collections::HashMap; + +use crate::TketOp; +use crate::extension::rotation::ConstRotation; +use crate::hugr::HugrView; +use crate::hugr::Node; +use crate::hugr::NodeIndex; +use crate::hugr::hugr::{ValidationError, hugrmut::HugrMut}; +use crate::passes::guppy::{NormalizeGuppy, NormalizeGuppyErrors}; +use crate::{Hugr, hugr, op_matches}; +use hugr::algorithms::ComposablePass; +use hugr::std_extensions::arithmetic::float_types::ConstF64; +use hugr_core::OutgoingPort; +use rsgridsynth::config::config_from_theta_epsilon; +use rsgridsynth::gridsynth::gridsynth_gates; + +/// Errors that can occur during the Gridsynth pass due to acting on a hugr that +/// goes beyond the scope of what the pass can optimise. The most likely reasons for this +/// are that the Rz angle is defined at runtime or that the NormalizeGuppy pass is unable +/// to standardise the form of the HUGR enough. Issues may occur when the constant node +/// providing the angle crosses function boundaries or if the control flow is especially +/// complicated +#[derive(derive_more::Error, Debug, derive_more::Display, derive_more::From)] +pub enum GridsynthError { + /// Error during the NormalizeGuppy pass + NormalizeGuppyErrors(NormalizeGuppyErrors), + /// Error during validation of the HUGR + ValidationError(ValidationError), +} +/// Garbage collector for cleaning up constant nodes providing angles to the Rz gates and +/// all nodes on the path to them. +struct GarbageCollector { + references: HashMap, // key: node index (of Const node containing angle), + // value: reference counter for that node + path: HashMap>, // key: node index (of Const node containing angle), + // value: the nodes leading up to the constant node and the constant + // node itself +} // CONCERN: I am concerned that this approach may not clean up properly if the Guppy user +// has redundant calls of the constant (eg, has used it to define another constant but then +// not used that constant). + +impl GarbageCollector { + /// Add references to constant node + fn add_references(&mut self, node: Node, increment: usize) { + // if reference not in references add it with the default value 1, else increment count + let count = self.references.entry(node.index()).or_insert(1); + *count += increment; + } + + /// Remove reference to a constant node + fn remove_references(&mut self, node: Node, increment: usize) { + // reduce reference count + let count = self.references.get_mut(&node.index()).unwrap(); + *count -= increment; + } + + /// Infer how many references there are to the angle-containing Const node + /// given the corresponding `load_const_node` + fn infer_references_to_angle( + &mut self, + hugr: &mut Hugr, + load_const_node: Node, + const_node: Node, + ) { + let references_collection: Vec<_> = hugr.node_outputs(load_const_node).collect(); + let num_references = references_collection.len(); + // if reference not in references add it with the default value num_references, else do nothing + self.references + .entry(const_node.index()) + .or_insert(num_references); + } + + /// If there are no references remaining to const_node, remove it and the nodes leading to it + fn collect(&mut self, hugr: &mut Hugr, const_node: Node) { + let node_index = &const_node.index(); + if self.references[node_index] == 0 { + let path: Vec = self.path.get(node_index).unwrap().to_vec(); + for node in path { + hugr.remove_node(node); + } + } + } +} + +/// Find the nodes for the Rz gates. +fn find_rzs(hugr: &mut Hugr) -> Option> { + let mut rz_nodes: Vec = Vec::new(); + for node in hugr.nodes() { + let op_type = hugr.get_optype(node); + if op_matches(op_type, TketOp::Rz) { + rz_nodes.push(node); + } + } + + // if there are rz_nodes: + if !(rz_nodes.is_empty()) { + return Some(rz_nodes); + } + None +} + +/// Find the output port and node linked to the input specified by `port_idx` for `node` +fn find_single_linked_output_by_index( + hugr: &mut Hugr, + node: Node, + port_idx: usize, +) -> (Node, OutgoingPort) { + let ports = hugr.node_inputs(node); + let collected_ports: Vec<_> = ports.collect(); + + hugr.single_linked_output(node, collected_ports[port_idx]) + .expect("Not yet set-up to handle cases where there are no previous nodes") +} + +/// Find the constant node containing the angle to be inputted to the Rz gate. +/// It is assumed that `hugr` has had the NormalizeGuppy passes applied to it +/// prior to being applied. This function also cleans up behind itself removing +/// everything on the path to the `angle_node` but not the `angle_node` itself, +/// which is still needed. +fn find_angle_node( + hugr: &mut Hugr, + rz_node: Node, + garbage_collector: &mut GarbageCollector, +) -> Node { + // Find linked ports to the rz port where the angle will be inputted + // the port offset of the angle is known to be 1 for the rz gate. + let (mut prev_node, _) = find_single_linked_output_by_index(hugr, rz_node, 1); + + // As all of the NormalizeGuppy passes have been run on the `hugr` before it enters this function, + // and these passes include constant folding, we can assume that we can follow the 0th ports back + // to a constant node where the angle is defined. + let mut path = Vec::new(); // The nodes leading up to the angle_node and the angle_node + // itself + loop { + let (current_node, _) = find_single_linked_output_by_index(hugr, prev_node, 0); + let op_type = hugr.get_optype(current_node); + path.push(current_node); + + garbage_collector.add_references(current_node, 1); + + if op_type.is_const() { + let load_const_node = prev_node; + let angle_node = current_node; + // Add references to angle node if this has not already been done + garbage_collector.infer_references_to_angle(hugr, load_const_node, angle_node); + // Remove one reference to reflect the fact that we are about to use the angle node + garbage_collector.remove_references(angle_node, 1); + // Let garbage collector know what nodes led to the angle node + garbage_collector + .path + .entry(angle_node.index()) + .or_insert(path); + return angle_node; + } + + prev_node = current_node; + } +} + +fn find_angle(hugr: &mut Hugr, rz_node: Node, garbage_collector: &mut GarbageCollector) -> f64 { + let angle_node = find_angle_node(hugr, rz_node, garbage_collector); + let op_type = hugr.get_optype(angle_node); + let angle_const = op_type.as_const().unwrap(); + let angle_val = &angle_const.value; + + // Handling likely angle formats. Panic if angle is not one of the anticipated formats + let angle = if let Some(rot) = angle_val.get_custom_value::() { + rot.to_radians() + } else if let Some(fl) = angle_val.get_custom_value::() { + let half_turns = fl.value(); + ConstRotation::new(half_turns).unwrap().to_radians() + } else { + panic!("Angle not specified as ConstRotation or ConstF64") + }; + + // We now have what we need to know from the angle node and can remove it from the HUGR if + // no further references remain to it + garbage_collector.collect(hugr, angle_node); + + angle +} + +/// Call gridsynth on the angle of the Rz gate node provided. +/// If `simplify` is `true`, the sequence of gridsynth gates is compressed into +/// a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence +/// always has a smaller number of S and H gates, and the same number of T+Tdg gates. +fn get_gridsynth_gates( + hugr: &mut Hugr, + rz_node: Node, + epsilon: f64, + simplify: bool, + garbage_collector: &mut GarbageCollector, +) -> String { + let theta = find_angle(hugr, rz_node, garbage_collector); + let seed = 1234; + let verbose = false; + let up_to_phase = false; + let mut gridsynth_config = + config_from_theta_epsilon(theta, epsilon, seed, verbose, up_to_phase); + let mut gate_sequence = gridsynth_gates(&mut gridsynth_config).gates; + + if simplify { + let n = gate_sequence.len(); + let mut normal_form_reached = false; + while !normal_form_reached { + // Not the most efficient, but it's easiest to reach the normal form by doing + // string rewrites. + // TODO: Can be done with Regex, preferably by providing all LHS to the Regex + // so they are all matched at once; then we replace each one accordingly. + // NOTE: Ignoring global phase factors + let new_gate_sequence = gate_sequence + // Cancellation rules + .replacen("ZZ", "", n) + .replacen("XX", "", n) + .replacen("HH", "", n) + .replacen("SS", "Z", n) + .replacen("TT", "S", n) + .replacen("DD", "SZ", n) + .replacen("TD", "", n) + .replacen("DT", "", n) + // Rules to push Paulis to the right + .replacen("ZS", "SZ", n) + .replacen("ZT", "TZ", n) + .replacen("ZD", "DZ", n) + .replacen("XS", "SZX", n) + .replacen("XT", "DX", n) + .replacen("XD", "TX", n) + .replacen("ZH", "HX", n) + .replacen("XH", "HZ", n) + .replacen("XZ", "ZX", n) + // Interaction of H and S (reduces number of H) + .replacen("HSH", "SHSX", n) + // Interaction of S and T (reduces number of S) + .replacen("DS", "T", n) + .replacen("SD", "T", n) + .replacen("TS", "DZ", n) + .replacen("ST", "DZ", n); + // Stop when no more changes are possible + normal_form_reached = new_gate_sequence == gate_sequence; + gate_sequence = new_gate_sequence; + } + } + gate_sequence +} + +fn replace_rz_with_gridsynth_output( + hugr: &mut Hugr, + rz_node: Node, + gates: &str, +) -> Result<(), GridsynthError> { + // Remove the W i.e. the exp(i*pi/4) global phases + let gate_sequence = gates.replacen("W", "", gates.len()); + // Add the nodes of the gridsynth sequence into the HUGR + let gridsynth_nodes: Vec = gate_sequence + .chars() + .map(|gate| match gate { + 'H' => TketOp::H, + 'S' => TketOp::S, + 'T' => TketOp::T, + 'D' => TketOp::Tdg, + 'X' => TketOp::X, + 'Z' => TketOp::Z, + _ => panic!("The gate {gate} is not supported"), + }) + .map(|op| hugr.add_node_after(rz_node, op)) // No connections just yet + .collect(); // Force the nodes to actually be added + + // Get the node that connects to the input of the Rz gate + let rz_input_port = hugr.node_inputs(rz_node).next().unwrap(); + let (mut prev_node, mut prev_port) = hugr.single_linked_output(rz_node, rz_input_port).unwrap(); + // Get the node that connects to the output of the Rz gate + let rz_output_port = hugr.node_outputs(rz_node).next().unwrap(); + let (next_node, next_port) = hugr.single_linked_input(rz_node, rz_output_port).unwrap(); + // We have now inferred what we need to know from the Rz node; we can remove it + hugr.remove_node(rz_node); + + // Connect the gridsynth nodes + for current_node in gridsynth_nodes { + // Connect the current node with the previous node + let current_port = hugr.node_inputs(current_node).next().unwrap(); + hugr.connect(prev_node, prev_port, current_node, current_port); + // Update who is the prev_node + prev_node = current_node; + prev_port = hugr.node_outputs(prev_node).next().unwrap(); + } + // Finally, connect the last gridsynth node with the node that came after the Rz + hugr.connect(prev_node, prev_port, next_node, next_port); + // hugr.validate()?; + Ok(()) +} + +/// Replace an Rz gate with the corresponding gates outputted by gridsynth. +/// If `simplify` is `true`, the sequence of gridsynth gates is compressed into +/// a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence +/// always has a smaller number of S and H gates, and the same number of T+Tdg gates. +pub fn apply_gridsynth_pass( + hugr: &mut Hugr, + epsilon: f64, + simplify: bool, +) -> Result<(), GridsynthError> { + // Running passes to convert HUGR to standard form + NormalizeGuppy::default() + .simplify_cfgs(true) + .remove_tuple_untuple(true) + .constant_folding(true) + .remove_dead_funcs(true) + .inline_dfgs(true) + .run(hugr)?; + + let rz_nodes = find_rzs(hugr).unwrap(); + let mut garbage_collector = GarbageCollector { + references: HashMap::new(), + path: HashMap::new(), + }; + for node in rz_nodes { + let gates = get_gridsynth_gates(hugr, node, epsilon, simplify, &mut garbage_collector); + replace_rz_with_gridsynth_output(hugr, node, &gates)?; + } + Ok(()) +} + +/// Example error. +#[derive(Debug, derive_more::Display, derive_more::Error)] +#[display("Example error: {message}")] +pub struct ExampleError { + message: String, +} + +// The following tests only check if any errors occur because Selene is challenging to access from the rust +// API. However, Selene simulations in Python versions of the HUGRs in these tests and more complicated HUGRS are +// available at https://github.com/Quantinuum/gridsynth_guppy_demo.git +#[cfg(test)] +mod tests { + use super::*; + + use crate::extension::bool::bool_type; + use crate::extension::rotation::ConstRotation; + use crate::hugr::builder::{Container, DFGBuilder, Dataflow, HugrBuilder}; + use crate::hugr::extension::prelude::qb_t; + use crate::hugr::ops::Value; + use crate::hugr::types::Signature; + use hugr::builder::DataflowHugr; + + fn build_rz_only_circ() -> (Hugr, Node) { + let theta = 0.64; + let qb_row = vec![qb_t(); 1]; + let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), qb_row)).unwrap(); + let [q_in] = h.input_wires_arr(); + + let constant = h.add_constant(Value::extension( + ConstRotation::from_radians(theta).unwrap(), + )); + let loaded_const = h.load_const(&constant); + let rz = h.add_dataflow_op(TketOp::Rz, [q_in, loaded_const]).unwrap(); + let _ = h.set_outputs(rz.outputs()); + let mut circ = h.finish_hugr().unwrap(); + circ.validate().unwrap_or_else(|e| panic!("{e}")); + let rz_nodes = find_rzs(&mut circ).unwrap(); + let rz_node = rz_nodes[0]; + (circ, rz_node) + } + + fn build_non_trivial_circ() -> Hugr { + // Defining some angles for Rz gates in radians + let alpha = 0.23; + let beta = 1.78; + let inverse_angle = -alpha - beta; + + // Defining builder for circuit + let qb_row = vec![qb_t(); 1]; + let meas_row = vec![bool_type(); 1]; + let mut builder = + DFGBuilder::new(Signature::new(qb_row.clone(), meas_row.clone())).unwrap(); + let [q1] = builder.input_wires_arr(); + + // Adding constant wires and nodes + let alpha_const = builder.add_constant(Value::extension( + ConstRotation::from_radians(alpha).unwrap(), + )); + let loaded_alpha = builder.load_const(&alpha_const); + let beta_const = + builder.add_constant(Value::extension(ConstRotation::from_radians(beta).unwrap())); + let loaded_beta = builder.load_const(&beta_const); + let inverse_const = builder.add_constant(Value::extension( + ConstRotation::from_radians(inverse_angle).unwrap(), + )); + let loaded_inverse = builder.load_const(&inverse_const); + + // Adding gates and measurements + let had1 = builder.add_dataflow_op(TketOp::H, [q1]).unwrap(); + let [q1] = had1.outputs_arr(); + let rz_alpha = builder + .add_dataflow_op(TketOp::Rz, [q1, loaded_alpha]) + .unwrap(); + let [q1] = rz_alpha.outputs_arr(); + let rz_beta = builder + .add_dataflow_op(TketOp::Rz, [q1, loaded_beta]) + .unwrap(); + let [q1] = rz_beta.outputs_arr(); + let rz_inverse = builder + .add_dataflow_op(TketOp::Rz, [q1, loaded_inverse]) + .unwrap(); + let [q1] = rz_inverse.outputs_arr(); + let had2 = builder.add_dataflow_op(TketOp::H, [q1]).unwrap(); + let [q1] = had2.outputs_arr(); + let meas_res = builder + .add_dataflow_op(TketOp::MeasureFree, [q1]) + .unwrap() + .out_wire(0); + + builder + .finish_hugr_with_outputs([meas_res]) + .unwrap_or_else(|e| panic!("{e}")) + } + + fn build_non_trivial_circ_2qubits() -> Hugr { + // Defining some angles for Rz gates in radians + let alpha = 0.23; + let beta = 1.78; + let inverse_angle = -alpha - beta; + + // Defining builder for circuit + let qb_row = vec![qb_t(); 2]; + let meas_row = vec![bool_type(); 2]; + let mut builder = + DFGBuilder::new(Signature::new(qb_row.clone(), meas_row.clone())).unwrap(); + let [q1, q2] = builder.input_wires_arr(); + + // Adding constant wires and nodes + let alpha_const = builder.add_constant(Value::extension( + ConstRotation::from_radians(alpha).unwrap(), + )); + let loaded_alpha = builder.load_const(&alpha_const); + let beta_const = + builder.add_constant(Value::extension(ConstRotation::from_radians(beta).unwrap())); + let loaded_beta = builder.load_const(&beta_const); + let inverse_const = builder.add_constant(Value::extension( + ConstRotation::from_radians(inverse_angle).unwrap(), + )); + let loaded_inverse = builder.load_const(&inverse_const); + + // Adding gates and measurements + let had1 = builder.add_dataflow_op(TketOp::H, [q1]).unwrap(); + let [q1] = had1.outputs_arr(); + let rz_alpha = builder + .add_dataflow_op(TketOp::Rz, [q1, loaded_alpha]) + .unwrap(); + let [q1] = rz_alpha.outputs_arr(); + let rz_beta = builder + .add_dataflow_op(TketOp::Rz, [q1, loaded_beta]) + .unwrap(); + let [q1] = rz_beta.outputs_arr(); + let x = builder.add_dataflow_op(TketOp::X, [q2]).unwrap(); + let [q2] = x.outputs_arr(); + let cx1 = builder.add_dataflow_op(TketOp::CX, [q2, q1]).unwrap(); + let [q2, q1] = cx1.outputs_arr(); + let rz_inverse = builder + .add_dataflow_op(TketOp::Rz, [q1, loaded_inverse]) + .unwrap(); + let [q1] = rz_inverse.outputs_arr(); + let cx2 = builder.add_dataflow_op(TketOp::CX, [q2, q1]).unwrap(); + let [q2, q1] = cx2.outputs_arr(); + let had2 = builder.add_dataflow_op(TketOp::H, [q1]).unwrap(); + let [q1] = had2.outputs_arr(); + let meas_res1 = builder + .add_dataflow_op(TketOp::MeasureFree, [q1]) + .unwrap() + .out_wire(0); + let meas_res2 = builder + .add_dataflow_op(TketOp::MeasureFree, [q2]) + .unwrap() + .out_wire(0); + + builder + .finish_hugr_with_outputs([meas_res1, meas_res2]) + .unwrap_or_else(|e| panic!("{e}")) + } + + #[test] + fn gridsynth_pass_successful() { + // This test is just to check if a panic occurs + let (mut circ, _) = build_rz_only_circ(); + let epsilon: f64 = 1e-3; + apply_gridsynth_pass(&mut circ, epsilon, true).unwrap(); + } + + #[test] + fn test_non_trivial_circ_1qubit() { + // Due to challenge of accessing Selene from rust, this just + // tests for if errors occur. It would be nice to have a call to + // Selene here. (See https://github.com/Quantinuum/gridsynth_guppy_demo.git for a Python example + // of this circuit working) + let epsilon = 1e-2; + let mut hugr = build_non_trivial_circ(); + + apply_gridsynth_pass(&mut hugr, epsilon, true).unwrap(); + } + + #[test] + fn test_non_trivial_circ_2qubits() { + // Due to challenge of accessing Selene from rust, this just + // tests for if errors occur. It would be nice to have a call to + // Selene here. (See https://github.com/Quantinuum/gridsynth_guppy_demo.git for a Python example + // of this circuit working) + let epsilon = 1e-2; + let mut hugr = build_non_trivial_circ_2qubits(); + + apply_gridsynth_pass(&mut hugr, epsilon, true).unwrap(); + } +} diff --git a/tket1-passes/conanfile.txt b/tket1-passes/conanfile.txt index d235d38fd..01871ccce 100644 --- a/tket1-passes/conanfile.txt +++ b/tket1-passes/conanfile.txt @@ -1,2 +1,2 @@ [requires] -tket-c-api/2.1.64@tket/stable +tket-c-api/2.1.67@tket/stable