Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tket-py/src/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//! PyO3 wrapper for rewriters.

use derive_more::From;
use hugr::HugrView;
use hugr::{hugr::views::SiblingSubgraph, HugrView};
use itertools::Itertools;
use pyo3::prelude::*;
use std::path::PathBuf;
use tket::{
rewrite::{CircuitRewrite, ECCRewriter, Rewriter, Subcircuit},
rewrite::{CircuitRewrite, ECCRewriter, Rewriter},
Circuit,
};

Expand Down Expand Up @@ -59,7 +59,7 @@ impl PyCircuitRewrite {
Ok(Self {
rewrite: CircuitRewrite::try_new(
&source_position.0,
&source_circ.circ,
source_circ.circ.hugr(),
replacement.circ,
)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?,
Expand Down Expand Up @@ -103,15 +103,15 @@ impl Rewriter for PyRewriter {
#[pyo3(name = "Subcircuit")]
#[derive(Debug, Clone, From)]
#[repr(transparent)]
pub struct PySubcircuit(Subcircuit);
pub struct PySubcircuit(SiblingSubgraph);

#[pymethods]
impl PySubcircuit {
#[new]
fn from_nodes(nodes: Vec<PyNode>, circ: &Tk2Circuit) -> PyResult<Self> {
let nodes: Vec<_> = nodes.into_iter().map_into().collect();
Ok(Self(
Subcircuit::try_from_nodes(nodes, &circ.circ)
SiblingSubgraph::try_from_nodes(nodes, circ.circ.hugr())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?,
))
}
Expand Down
15 changes: 8 additions & 7 deletions tket/src/passes/tuple_unpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ use core::panic;
use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::{MakeTuple, TupleOpDef};
use hugr::extension::simple_op::MakeExtensionOp;
use hugr::hugr::views::SiblingSubgraph;
use hugr::ops::{OpTrait, OpType};
use hugr::types::Type;
use hugr::{HugrView, Node};
use itertools::Itertools;

use crate::circuit::Command;
use crate::rewrite::{CircuitRewrite, Subcircuit};
use crate::rewrite::CircuitRewrite;
use crate::Circuit;

/// Find tuple pack operations followed by tuple unpack operations
Expand Down Expand Up @@ -102,8 +103,8 @@ fn remove_pack_unpack<T: HugrView<Node = Node>>(

let mut nodes = unpack_nodes;
nodes.push(pack_node);
let subcirc = Subcircuit::try_from_nodes(nodes, circ).unwrap();
let subcirc_signature = subcirc.signature(circ);
let subgraph = SiblingSubgraph::try_from_nodes(nodes, circ.hugr()).expect("is convex");
let subcirc_signature = subgraph.signature(circ.hugr());

// The output port order in `Subcircuit::try_from_nodes` is not too well defined.
// Check that the outputs are in the expected order.
Expand Down Expand Up @@ -142,14 +143,14 @@ fn remove_pack_unpack<T: HugrView<Node = Node>>(
.finish_hugr_with_outputs(outputs)
.unwrap_or_else(|e| {
panic!("Failed to create replacement for removing tuple pack/unpack operations. {e}")
})
.into();
});

subcirc
.create_rewrite(circ, replacement)
subgraph
.create_simple_replacement(circ.hugr(), replacement)
.unwrap_or_else(|e| {
panic!("Failed to create rewrite for removing tuple pack/unpack operations. {e}")
})
.into()
}

#[cfg(test)]
Expand Down
20 changes: 9 additions & 11 deletions tket/src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ use portmatching::{
};
use smol_str::SmolStr;

use crate::{
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
};
use crate::{circuit::Circuit, rewrite::CircuitRewrite};

/// Matchable operations in a circuit.
#[derive(
Expand Down Expand Up @@ -75,7 +72,8 @@ fn encode_op(op: OpType) -> Option<Vec<u8>> {
/// pattern from the matcher.
#[derive(Clone)]
pub struct PatternMatch {
position: Subcircuit,
/// The matched subgraph.
subgraph: SiblingSubgraph,
pattern: PatternID,
/// The root of the pattern in the circuit.
///
Expand All @@ -96,13 +94,13 @@ impl PatternMatch {
}

/// Returns the matched subcircuit in the original circuit.
pub fn subcircuit(&self) -> &Subcircuit {
&self.position
pub fn subgraph(&self) -> &SiblingSubgraph {
&self.subgraph
}

/// Returns the matched nodes in the original circuit.
pub fn nodes(&self) -> &[Node] {
self.position.nodes()
self.subgraph.nodes()
}

/// Create a pattern match from the image of a pattern root.
Expand Down Expand Up @@ -205,7 +203,7 @@ impl PatternMatch {
let subgraph =
SiblingSubgraph::try_new_with_checker(inputs, outputs, circ.hugr(), checker)?;
Ok(Self {
position: subgraph.into(),
subgraph,
pattern,
root,
})
Expand All @@ -217,15 +215,15 @@ impl PatternMatch {
source: &Circuit<impl HugrView<Node = Node>>,
target: Circuit,
) -> Result<CircuitRewrite, InvalidReplacement> {
CircuitRewrite::try_new(&self.position, source, target)
CircuitRewrite::try_new(&self.subgraph, source.hugr(), target)
}
}

impl Debug for PatternMatch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PatternMatch")
.field("root", &self.root)
.field("nodes", &self.position.subgraph.nodes())
.field("nodes", &self.subgraph.nodes())
.finish()
}
}
Expand Down
80 changes: 8 additions & 72 deletions tket/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@ pub mod ecc_rewriter;
pub mod strategy;
pub mod trace;

use bytemuck::TransparentWrapper;
#[cfg(feature = "portmatching")]
pub use ecc_rewriter::ECCRewriter;

use derive_more::{From, Into};
use hugr::core::HugrNode;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::patch::simple_replace;
use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph};
use hugr::hugr::views::sibling_subgraph::InvalidReplacement;
use hugr::hugr::Patch;
use hugr::types::Signature;
use hugr::{
hugr::{views::SiblingSubgraph, SimpleReplacementError},
SimpleReplacement,
Expand All @@ -24,83 +21,22 @@ use hugr::{Hugr, HugrView, Node};

use crate::circuit::Circuit;

/// A subcircuit of a circuit.
#[derive(Debug, Clone, From, Into)]
#[repr(transparent)]
pub struct Subcircuit<N = Node> {
pub(crate) subgraph: SiblingSubgraph<N>,
}

unsafe impl<N> TransparentWrapper<SiblingSubgraph<N>> for Subcircuit<N> {}

impl<N: HugrNode> Subcircuit<N> {
/// Create a new subcircuit induced from a set of nodes.
pub fn try_from_nodes(
nodes: impl Into<Vec<N>>,
circ: &Circuit<impl HugrView<Node = N>>,
) -> Result<Self, InvalidSubgraph<N>> {
let subgraph = SiblingSubgraph::try_from_nodes(nodes, circ.hugr())?;
Ok(Self { subgraph })
}

/// Nodes in the subcircuit.
pub fn nodes(&self) -> &[N] {
self.subgraph.nodes()
}

/// Number of nodes in the subcircuit.
pub fn node_count(&self) -> usize {
self.subgraph.node_count()
}

/// The signature of the subcircuit.
pub fn signature(&self, circ: &Circuit<impl HugrView<Node = N>>) -> Signature {
self.subgraph.signature(circ.hugr())
}
}

impl Subcircuit<Node> {
/// Create a rewrite rule to replace the subcircuit with a new circuit.
///
/// # Parameters
/// * `circuit` - The base circuit that contains the subcircuit.
/// * `replacement` - The new circuit to replace the subcircuit with.
pub fn create_rewrite(
&self,
circuit: &Circuit<impl HugrView<Node = Node>>,
replacement: Circuit<impl HugrView<Node = Node>>,
) -> Result<CircuitRewrite, InvalidReplacement> {
// The replacement must be a Dfg rooted hugr.
let replacement = replacement
.extract_dfg()
.unwrap_or_else(|e| panic!("{}", e))
.into_hugr();
Ok(CircuitRewrite(
self.subgraph
.create_simple_replacement(circuit.hugr(), replacement)?,
))
}
}

/// A rewrite rule for circuits.
#[derive(Debug, Clone, From, Into)]
pub struct CircuitRewrite<N = Node>(SimpleReplacement<N>);

impl CircuitRewrite {
/// Create a new rewrite rule.
pub fn try_new(
circuit_position: &Subcircuit,
circuit: &Circuit<impl HugrView<Node = Node>>,
subgraph: &SiblingSubgraph,
hugr: &impl HugrView<Node = Node>,
replacement: Circuit<impl HugrView<Node = Node>>,
) -> Result<Self, InvalidReplacement> {
let replacement = replacement
.extract_dfg()
.unwrap_or_else(|e| panic!("{}", e))
.into_hugr();
circuit_position
.subgraph
.create_simple_replacement(circuit.hugr(), replacement)
.map(Self)
Ok(Self(subgraph.create_simple_replacement(hugr, replacement)?))
}

/// Number of nodes added or removed by the rewrite.
Expand All @@ -109,13 +45,13 @@ impl CircuitRewrite {
/// number is an increase in node count, a negative number is a decrease.
pub fn node_count_delta(&self) -> isize {
let new_count = self.replacement().num_operations() as isize;
let old_count = self.subcircuit().node_count() as isize;
let old_count = self.subgraph().node_count() as isize;
new_count - old_count
}

/// The subcircuit that is replaced.
pub fn subcircuit(&self) -> &Subcircuit {
Subcircuit::wrap_ref(self.0.subgraph())
/// The subgraph that is replaced.
pub fn subgraph(&self) -> &SiblingSubgraph {
self.0.subgraph()
}

/// The replacement subcircuit.
Expand Down
45 changes: 27 additions & 18 deletions tket/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ pub trait RewriteStrategy {
/// Returns the cost of a rewrite's matched subcircuit before replacing it.
#[inline]
fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Circuit) -> Self::Cost {
circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), |op| {
self.op_cost(op)
})
circ.nodes_cost(rw.subgraph().nodes().iter().copied(), |op| self.op_cost(op))
}

/// Returns the expected cost of a rewrite's matched subcircuit after replacing it.
Expand Down Expand Up @@ -129,14 +127,14 @@ impl RewriteStrategy for GreedyRewriteStrategy {
let mut circ = circ.clone();
for rewrite in rewrites {
if rewrite
.subcircuit()
.subgraph()
.nodes()
.iter()
.any(|n| changed_nodes.contains(n))
{
continue;
}
changed_nodes.extend(rewrite.subcircuit().nodes().iter().copied());
changed_nodes.extend(rewrite.subgraph().nodes().iter().copied());
cost_delta += rewrite.node_count_delta();
rewrite
.apply(&mut circ)
Expand Down Expand Up @@ -474,15 +472,30 @@ impl GammaStrategyCost<fn(&OpType) -> usize> {
#[cfg(test)]
mod tests {
use super::*;
use hugr::hugr::views::SiblingSubgraph;
use hugr::Node;
use itertools::Itertools;

use crate::rewrite::trace::REWRITE_TRACING_ENABLED;
use crate::{
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
utils::build_simple_circuit,
};
use crate::{circuit::Circuit, rewrite::CircuitRewrite, utils::build_simple_circuit};

/// Create a rewrite rule to replace the subcircuit with a new circuit.
/// TODO: this should use the new Subcircuit; TEMP TEST WORKAROUND until that arrives.
///
/// # Parameters
/// * `circuit` - The base circuit that contains the subcircuit.
/// * `replacement` - The new circuit to replace the subcircuit with.
fn create_rewrite(
ssg: &SiblingSubgraph<Node>,
circuit: &Circuit<impl HugrView<Node = Node>>,
replacement: Circuit<impl HugrView<Node = Node>>,
) -> CircuitRewrite {
// The replacement must be a Dfg rooted hugr.
let replacement = replacement.extract_dfg().unwrap().into_hugr();
ssg.create_simple_replacement(circuit.hugr(), replacement)
.unwrap()
.into()
}

fn n_cx(n_gates: usize) -> Circuit {
let qbs = [0, 1];
Expand All @@ -497,18 +510,14 @@ mod tests {

/// Rewrite cx_nodes -> empty
fn rw_to_empty(circ: &Circuit, cx_nodes: impl Into<Vec<Node>>) -> CircuitRewrite {
let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap();
subcirc
.create_rewrite(circ, n_cx(0))
.unwrap_or_else(|e| panic!("{}", e))
let subcirc = SiblingSubgraph::try_from_nodes(cx_nodes, circ.hugr()).unwrap();
create_rewrite(&subcirc, circ, n_cx(0))
}

/// Rewrite cx_nodes -> 10x CX
fn rw_to_full(circ: &Circuit, cx_nodes: impl Into<Vec<Node>>) -> CircuitRewrite {
let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap();
subcirc
.create_rewrite(circ, n_cx(10))
.unwrap_or_else(|e| panic!("{}", e))
let subcirc = SiblingSubgraph::try_from_nodes(cx_nodes, circ.hugr()).unwrap();
create_rewrite(&subcirc, circ, n_cx(10))
}

#[test]
Expand Down
Loading