Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tket-py/src/passes/gridsynth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ create_py_exception!(
/// Binding to a python function called gridsynth that runs the rust function called
/// apply_gridsynth pass behind the scenes
#[pyfunction]
pub fn gridsynth<'py>(circ: &Bound<'py, PyAny>, epsilon: f64) -> PyResult<Bound<'py, PyAny>> {
pub fn gridsynth<'py>(
circ: &Bound<'py, PyAny>,
epsilon: f64,
simplify: bool,
) -> PyResult<Bound<'py, PyAny>> {
let py = circ.py();

try_with_circ(circ, |mut circ: tket::Circuit, typ: CircuitType| {
apply_gridsynth_pass(circ.hugr_mut(), epsilon).convert_pyerrs()?;
apply_gridsynth_pass(circ.hugr_mut(), epsilon, simplify).convert_pyerrs()?;

let circ = typ.convert(py, circ)?;
PyResult::Ok(circ)
Expand Down
5 changes: 4 additions & 1 deletion tket-py/tket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,14 @@ def tket1_pass(
nested inside other subregions of the circuit.
"""

def gridsynth(hugr: CircuitClass, epsilon: float) -> CircuitClass:
def gridsynth(hugr: CircuitClass, epsilon: float, simplify: bool) -> CircuitClass:
"""Runs a pass applying the gridsynth algorithm to all Rz gates in a HUGR,
which decomposes them into the Clifford + T basis.

Parameters:
- hugr: the hugr to run the pass on.
- epsilon: the precision of the gridsynth decomposition
- simplify: if `True`, each sequence of gridsynth gates is compressed into
a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence
always has a smaller number of S and H gates, and the same number of T+Tdg gates.
"""
12 changes: 9 additions & 3 deletions tket-py/tket/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def _normalize(self, hugr: Hugr, inplace: bool) -> PassResult:
@dataclass
class Gridsynth(ComposablePass):
epsilon: float
simplify: bool = True

"""Apply the gridsynth algorithm to all Rz gates in the Hugr.

Expand All @@ -191,6 +192,9 @@ class Gridsynth(ComposablePass):

Parameters:
- epsilon: The allowable error tolerance.
- simplify: If `True`, each sequence of gridsynth gates is compressed into
a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence
always has a smaller number of S and H gates, and the same number of T+Tdg gates.
"""

# TO DO: make the NormalizeGuppy pass optional, in case it is already run
Expand All @@ -206,13 +210,15 @@ def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult:
self,
hugr=hugr,
inplace=inplace,
copy_call=lambda h: self._apply_gridsynth_pass(hugr, self.epsilon, inplace),
copy_call=lambda h: self._apply_gridsynth_pass(
hugr, self.epsilon, self.simplify, inplace
),
)

def _apply_gridsynth_pass(
self, hugr: Hugr, epsilon: float, inplace: bool
self, hugr: Hugr, epsilon: float, simplify: bool, inplace: bool
) -> PassResult:
compiler_state: Tk2Circuit = Tk2Circuit.from_bytes(hugr.to_bytes())
program = gridsynth(compiler_state, epsilon)
program = gridsynth(compiler_state, epsilon, simplify)
new_hugr = Hugr.from_str(program.to_str())
return PassResult.for_pass(self, hugr=new_hugr, inplace=inplace, result=None)
214 changes: 107 additions & 107 deletions tket/src/passes/gridsynth.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
//! A pass that applies the gridsynth algorithm to all Rz gates in a HUGR.

/// The pass introduced here assumes that (1) all functions have been inlined and
/// (2) that NormalizeGuppy has been run. It expects that the following is guaranteed:
/// * That every Const node is immediately connected to a LoadConst node.
/// * That every constant is used in a single place (i.e. that there's a single output
/// of each LoadConst).
/// * That we can find the Const node connected to an Rz node by first following the
/// input port 1 of the Rz node (i.e. the angle argument) and then iteratively
/// following the input on port 0 until we reach a Const node.
use std::collections::HashMap;

use crate::TketOp;
Expand Down Expand Up @@ -131,8 +138,6 @@ fn find_angle_node(
// As all of the NormalizeGuppy passes have been run on the `hugr` before it enters this function,
// and these passes include constant folding, we can assume that we can follow the 0th ports back
// to a constant node where the angle is defined.
let max_iterations = 10;
let mut ii = 0;
let mut path = Vec::new(); // The nodes leading up to the angle_node and the angle_node
// itself
loop {
Expand All @@ -156,12 +161,8 @@ fn find_angle_node(
.or_insert(path);
return angle_node;
}
if ii >= max_iterations {
panic!("Angle finding failed");
}

prev_node = current_node;
ii += 1;
}
}

Expand All @@ -188,10 +189,15 @@ fn find_angle(hugr: &mut Hugr, rz_node: Node, garbage_collector: &mut GarbageCol
angle
}

fn apply_gridsynth(
/// Call gridsynth on the angle of the Rz gate node provided.
/// If `simplify` is `true`, the sequence of gridsynth gates is compressed into
/// a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence
/// always has a smaller number of S and H gates, and the same number of T+Tdg gates.
fn get_gridsynth_gates(
hugr: &mut Hugr,
epsilon: f64,
rz_node: Node,
epsilon: f64,
simplify: bool,
garbage_collector: &mut GarbageCollector,
) -> String {
let theta = find_angle(hugr, rz_node, garbage_collector);
Expand All @@ -200,113 +206,107 @@ fn apply_gridsynth(
let up_to_phase = false;
let mut gridsynth_config =
config_from_theta_epsilon(theta, epsilon, seed, verbose, up_to_phase);
let gates = gridsynth_gates(&mut gridsynth_config);
gates.gates
let mut gate_sequence = gridsynth_gates(&mut gridsynth_config).gates;

if simplify {
let n = gate_sequence.len();
let mut normal_form_reached = false;
while !normal_form_reached {
// Not the most efficient, but it's easiest to reach the normal form by doing
// string rewrites.
// TODO: Can be done with Regex, preferably by providing all LHS to the Regex

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering, do you have any idea of the efficiency of using regex vs these replacements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, other than it's usually preferred. The two main issues with performance of the current approach are:

  • It is not changing the Strings in place, but creating multiple of them. As far as I can tell, Regex doesn't help with that. According to the docs: "The only methods that allocate new strings are the string replacement methods. All other methods (searching and splitting) return borrowed references into the haystack given."
  • The other thing is that we scan the String once per iteration and per LHS of a rewrite rule. Ideally, we would find all matches of the same scan and replace them all at once. This can probably be done with Regex' replace_all since it looks like the Replacer can be a closure |&Captures| -> String (see examples in the replace docs).

After staring at the docs of Regex for 30min, I decided it was not worth doing it (I couldn't figure out how to do it) until we've got some evidence that we need better performance here.

// so they are all matched at once; then we replace each one accordingly.
// NOTE: Ignoring global phase factors
let new_gate_sequence = gate_sequence
// Cancellation rules
.replacen("ZZ", "", n)
.replacen("XX", "", n)
.replacen("HH", "", n)
.replacen("SS", "Z", n)
.replacen("TT", "S", n)
.replacen("DD", "SZ", n)
.replacen("TD", "", n)
.replacen("DT", "", n)
// Rules to push Paulis to the right
.replacen("ZS", "SZ", n)
.replacen("ZT", "TZ", n)
.replacen("ZD", "DZ", n)
.replacen("XS", "SZX", n)
.replacen("XT", "DX", n)
.replacen("XD", "TX", n)
.replacen("ZH", "HX", n)
.replacen("XH", "HZ", n)
.replacen("XZ", "ZX", n)
// Interaction of H and S (reduces number of H)
.replacen("HSH", "SHSX", n)
// Interaction of S and T (reduces number of S)
.replacen("DS", "T", n)
.replacen("SD", "T", n)
.replacen("TS", "DZ", n)
.replacen("ST", "DZ", n);
// Stop when no more changes are possible
normal_form_reached = new_gate_sequence == gate_sequence;
gate_sequence = new_gate_sequence;
}
}
gate_sequence
}

/// Add a gridsynth gate to some previous node, which may or may not be a gridsynth gate,
/// and connect
fn add_gate_and_connect(
hugr: &mut Hugr,
prev_node: Node,
op: hugr::ops::OpType,
output_node: Node,
qubit_providing_node: Node, // The node providing qubit to Rz gate
qubit_providing_port: OutgoingPort, // The output port providing qubit to Rz gate
) -> Node {
let current_node = hugr.add_node_after(output_node, op);
let ports: Vec<_> = hugr.node_outputs(prev_node).collect();

// If the previous node was the qubit_providing_node then it could have multiple
// outputs (eg, if multi-qubit gate and so need to be explicit about port)
let src_port = if prev_node.index() == qubit_providing_node.index() {
qubit_providing_port
} else {
// the ops generated by gridsynth are all single input single output gates, so
// it is safe to assume that there is only one output port
ports[0]
};
let ports: Vec<_> = hugr.node_inputs(current_node).collect();
let dst_port = ports[0];
hugr.connect(prev_node, src_port, current_node, dst_port);

current_node
} // TO DO: reduce number of arguments to this function. Six is too many.

fn replace_rz_with_gridsynth_output(
hugr: &mut Hugr,
rz_node: Node,
gates: &str,
) -> Result<(), GridsynthError> {
// getting node and output port that gave qubit to Rz gate
let inputs: Vec<_> = hugr.node_inputs(rz_node).collect();
let input_port = inputs[0];
let (qubit_providing_node, qubit_providing_port) =
hugr.single_linked_output(rz_node, input_port).unwrap();
let mut prev_node = qubit_providing_node;

// find output port
let outputs: Vec<_> = hugr.node_outputs(rz_node).collect();
let output_port = outputs[0];
let (next_node, dst_port) = hugr.single_linked_input(rz_node, output_port).unwrap();

// we have now inferred what we need to know from the Rz node we are replacing and can remove it
// Remove the W i.e. the exp(i*pi/4) global phases
let gate_sequence = gates.replacen("W", "", gates.len());
// Add the nodes of the gridsynth sequence into the HUGR
let gridsynth_nodes: Vec<Node> = gate_sequence

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice one! I had a TODO to rewrite this so thanks for taking care of it :)

.chars()
.map(|gate| match gate {
'H' => TketOp::H,
'S' => TketOp::S,
'T' => TketOp::T,
'D' => TketOp::Tdg,
'X' => TketOp::X,
'Z' => TketOp::Z,
_ => panic!("The gate {gate} is not supported"),
})
.map(|op| hugr.add_node_after(rz_node, op)) // No connections just yet
.collect(); // Force the nodes to actually be added

// Get the node that connects to the input of the Rz gate
let rz_input_port = hugr.node_inputs(rz_node).next().unwrap();
let (mut prev_node, mut prev_port) = hugr.single_linked_output(rz_node, rz_input_port).unwrap();
// Get the node that connects to the output of the Rz gate
let rz_output_port = hugr.node_outputs(rz_node).next().unwrap();
let (next_node, next_port) = hugr.single_linked_input(rz_node, rz_output_port).unwrap();
// We have now inferred what we need to know from the Rz node; we can remove it
hugr.remove_node(rz_node);

// recursively adding next gate in gates to prev_node
for gate in gates.chars() {
if gate == 'H' {
prev_node = add_gate_and_connect(
hugr,
prev_node,
TketOp::H.into(),
next_node,
qubit_providing_node,
qubit_providing_port,
);
} else if gate == 'S' {
prev_node = add_gate_and_connect(
hugr,
prev_node,
TketOp::S.into(),
next_node,
qubit_providing_node,
qubit_providing_port,
);
} else if gate == 'T' {
prev_node = add_gate_and_connect(
hugr,
prev_node,
TketOp::T.into(),
next_node,
qubit_providing_node,
qubit_providing_port,
);
} else if gate == 'X' {
prev_node = add_gate_and_connect(
hugr,
prev_node,
TketOp::X.into(),
next_node,
qubit_providing_node,
qubit_providing_port,
);
} else if gate == 'W' {
break; // Ignoring global phases for now.
} else {
panic!("The gate {gate} is not supported")
}
// Connect the gridsynth nodes
for current_node in gridsynth_nodes {
// Connect the current node with the previous node
let current_port = hugr.node_inputs(current_node).next().unwrap();
hugr.connect(prev_node, prev_port, current_node, current_port);
// Update who is the prev_node
prev_node = current_node;
prev_port = hugr.node_outputs(prev_node).next().unwrap();
}
let ports: Vec<_> = hugr.node_outputs(prev_node).collect();
// Assuming there were no outgoing ports to begin with when deciding port offset
let src_port = ports[0];
hugr.connect(prev_node, src_port, next_node, dst_port);
hugr.validate()?;
// Finally, connect the last gridsynth node with the node that came after the Rz
hugr.connect(prev_node, prev_port, next_node, next_port);
// hugr.validate()?;
Ok(())
}

/// Replace an Rz gate with the corresponding gates outputted by gridsynth
pub fn apply_gridsynth_pass(hugr: &mut Hugr, epsilon: f64) -> Result<(), GridsynthError> {
/// Replace an Rz gate with the corresponding gates outputted by gridsynth.
/// If `simplify` is `true`, the sequence of gridsynth gates is compressed into
/// a sequence of H*T and H*Tdg gates, sandwiched by Clifford gates. This sequence
/// always has a smaller number of S and H gates, and the same number of T+Tdg gates.
pub fn apply_gridsynth_pass(
hugr: &mut Hugr,
epsilon: f64,
simplify: bool,
) -> Result<(), GridsynthError> {
// Running passes to convert HUGR to standard form
NormalizeGuppy::default()
.simplify_cfgs(true)
Expand All @@ -322,7 +322,7 @@ pub fn apply_gridsynth_pass(hugr: &mut Hugr, epsilon: f64) -> Result<(), Gridsyn
path: HashMap::new(),
};
for node in rz_nodes {
let gates = apply_gridsynth(hugr, epsilon, node, &mut garbage_collector);
let gates = get_gridsynth_gates(hugr, node, epsilon, simplify, &mut garbage_collector);
replace_rz_with_gridsynth_output(hugr, node, &gates)?;
}
Ok(())
Expand Down Expand Up @@ -490,7 +490,7 @@ mod tests {
// This test is just to check if a panic occurs
let (mut circ, _) = build_rz_only_circ();
let epsilon: f64 = 1e-3;
apply_gridsynth_pass(&mut circ, epsilon).unwrap();
apply_gridsynth_pass(&mut circ, epsilon, true).unwrap();
}

#[test]
Expand All @@ -502,7 +502,7 @@ mod tests {
let epsilon = 1e-2;
let mut hugr = build_non_trivial_circ();

apply_gridsynth_pass(&mut hugr, epsilon).unwrap();
apply_gridsynth_pass(&mut hugr, epsilon, true).unwrap();
}

#[test]
Expand All @@ -514,6 +514,6 @@ mod tests {
let epsilon = 1e-2;
let mut hugr = build_non_trivial_circ_2qubits();

apply_gridsynth_pass(&mut hugr, epsilon).unwrap();
apply_gridsynth_pass(&mut hugr, epsilon, true).unwrap();
}
}
Loading