Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 20 additions & 29 deletions tket/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,33 +312,36 @@ impl<'h> PytketDecoderContext<'h> {
encoded_info: Option<&EncodedCircuitInfo>,
) -> Result<Node, PytketDecodeError> {
// Order the final wires according to the serial circuit register order.
let known_qubits = self
.wire_tracker
.known_pytket_qubits()
.cloned()
.collect_vec();
let known_bits = self.wire_tracker.known_pytket_bits().cloned().collect_vec();
let mut known_qubits: IndexSet<TrackedQubit> =
self.wire_tracker.known_pytket_qubits().cloned().collect();
let mut known_bits: IndexSet<TrackedBit> =
self.wire_tracker.known_pytket_bits().cloned().collect();

// Qubits and bits appearing at the output.
let mut qubits: IndexSet<TrackedQubit> = IndexSet::new();
let mut bits: IndexSet<TrackedBit> = IndexSet::new();
let mut qubits: Vec<TrackedQubit> = Vec::new();
let mut bits: Vec<TrackedBit> = Vec::new();
Comment on lines +321 to +322
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main fix.

We populated this with the registers declared at the Hugr output.

If the output said e.g. b[0], b[0], b[1] we ignored the second entry, and ended up outputting b[0], b[1], new_allocated_bit


if let Some(encoded_info) = encoded_info {
for qubit in encoded_info.output_qubits.iter() {
let id = self.wire_tracker.tracked_qubit_for_register(qubit)?;
qubits.insert(id.clone());
qubits.push(id.clone());
}
for bit in encoded_info.output_bits.iter() {
let id = self.wire_tracker.tracked_bit_for_register(bit)?;
bits.insert(id.clone());
bits.push(id.clone());
}
}
// Add any additional qubits or bits we have seen, without modifying the
// order of the qubits already there.

// Add any additional qubits or bits we have seen but haven't added to the list.
for q in &qubits {
known_qubits.shift_remove(q);
}
for b in &bits {
known_bits.shift_remove(b);
}
qubits.extend(known_qubits);
bits.extend(known_bits);
let qubits: Vec<TrackedQubit> = Vec::from_iter(qubits);
let bits: Vec<TrackedBit> = Vec::from_iter(bits);

let mut qubits_slice: &[TrackedQubit] = &qubits;
let mut bits_slice: &[TrackedBit] = &bits;

Expand Down Expand Up @@ -623,11 +626,6 @@ impl<'h> PytketDecoderContext<'h> {
/// and pytket parameters. Registers the node's output wires in the wire
/// tracker.
///
/// The qubits registers in `wires` are reused between the operation inputs
/// and outputs. Bit registers, on the other hand, are not reused. We use
/// the first registers in `wires` for the bit inputs and the remaining
/// registers for the outputs.
///
/// The input wire types must match the operation's input signature, no type
/// conversion is performed.
///
Expand Down Expand Up @@ -741,9 +739,6 @@ impl<'h> PytketDecoderContext<'h> {
.hugr_mut()
.connect(wire.node(), wire.source(), node, input_idx);
}
input_bits.iter().take(op_input_count.bits).for_each(|b| {
self.wire_tracker.mark_bit_outdated(b.clone());
});

// Register the output wires.
let output_qubits = output_qubits.iter().take(op_output_count.qubits).cloned();
Expand Down Expand Up @@ -814,9 +809,8 @@ impl<'h> PytketDecoderContext<'h> {
/// Given a new node in the HUGR, register all of its output wires in the
/// tracker.
///
/// Consumes the bits and qubits in order. Any unused bits and qubits are
/// marked as outdated, as they are assumed to have been consumed in the
/// inputs.
/// Consumes the bits and qubits in order. Any unused qubits are marked as
/// outdated, as they are assumed to have been consumed in the inputs.
pub fn register_node_outputs(
&mut self,
node: Node,
Expand Down Expand Up @@ -861,13 +855,10 @@ impl<'h> PytketDecoderContext<'h> {
.track_wire(wire, Arc::new(ty.clone()), wire_qubits, wire_bits)?;
}

// Mark any unused qubits and bits as outdated.
// Mark any unused qubits as outdated.
qubits.for_each(|q| {
self.wire_tracker.mark_qubit_outdated(q);
});
bits.for_each(|b| {
self.wire_tracker.mark_bit_outdated(b);
});

Ok(())
}
Expand Down
24 changes: 13 additions & 11 deletions tket/src/serialize/pytket/decoder/wires.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,14 +686,16 @@ impl WireTracker {
};

// List candidate wires that contain the qubits and bits we need.
let qubit_candidates = qubit_args
.first()
.into_iter()
.flat_map(|qb| self.qubit_wires(qb));
let bit_candidates = bit_args
.first()
.into_iter()
.flat_map(|bit| self.bit_wires(bit));
let qubit_candidates = if reg_count.qubits > 0 && !qubit_args.is_empty() {
itertools::Either::Left(self.qubit_wires(&qubit_args[0]))
} else {
itertools::Either::Right(std::iter::empty())
};
let bit_candidates = if reg_count.bits > 0 && !bit_args.is_empty() {
itertools::Either::Left(self.bit_wires(&bit_args[0]))
} else {
itertools::Either::Right(std::iter::empty())
};
let candidates = qubit_candidates.chain(bit_candidates).collect_vec();

// The bits and qubits we expect the wire to contain.
Expand Down Expand Up @@ -766,13 +768,13 @@ impl WireTracker {

// Convert the wire type, if needed.
let wire_data = &self.wires[&wire];
let new_wire = config.transform_typed_value(wire, wire_data.ty(), ty, builder)?;
let found_wire_type = wire_data.ty();
let new_wire = config.transform_typed_value(wire, found_wire_type, ty, builder)?;

if wire == new_wire {
Ok(FoundWire::Register(self.wires[&wire].clone()))
} else {
let ty: Arc<Type> = wire_data.ty.clone();
self.track_wire(new_wire, ty, wire_qubits, wire_bits)?;
self.track_wire(new_wire, Arc::new(ty.clone()), wire_qubits, wire_bits)?;
self.mark_wire_outdated(wire);
Ok(FoundWire::Register(self.wires[&new_wire].clone()))
}
Expand Down
10 changes: 8 additions & 2 deletions tket/src/serialize/pytket/extension/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,21 @@ impl<H: HugrView> PytketEmitter<H> for BoolEmitter {
// variable inputs. If new [`BoolOp`]s are added that do not follow
// this, the following code will need to be adjusted.
let bit_count = (num_inputs + num_outputs) as usize;
let output_bits = (0..num_outputs).collect_vec();
let output_bits = (num_inputs..(num_inputs + num_outputs)).collect_vec();
let mut expression = ClOperator::default();
expression.op = clop;
expression.args = (0..num_inputs)
.map(|i| ClArgument::Terminal(ClTerminal::Variable(ClVariable::Bit { index: i })))
.collect_vec();

let op = make_tk1_classical_expression(bit_count, &output_bits, &[], expression);
encoder.emit_node_command(node, hugr, EmitCommandOptions::new(), move |_| op)?;
encoder.emit_node_command(
node,
hugr,
// Output bits use new registers, so don't reuse any input bits.
EmitCommandOptions::new().reuse_bits(|_| vec![]),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by: The classical expression emitter was re-using the input bits to the node instead of using a separate output bit register.

move |_| op,
)?;
Ok(EncodeStatus::Success)
}

Expand Down
33 changes: 32 additions & 1 deletion tket/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::sync::Arc;
use super::TKETDecode;
use crate::TketOp;
use crate::extension::TKET1_EXTENSION_ID;
use crate::extension::bool::{BoolOp, bool_type};
use crate::extension::bool::{BoolOp, ConstBool, bool_type};
use crate::extension::rotation::{ConstRotation, RotationOp, rotation_type};
use crate::extension::sympy::SympyOpDef;
use crate::metadata;
Expand Down Expand Up @@ -287,6 +287,35 @@ fn circ_preset_qubits() -> Hugr {
hugr
}

/// A simple circuit with some preset input and output bit registers,
/// including multiple outputs for the same register.
#[fixture]
fn circ_preset_bits() -> Hugr {
let input_t = vec![bool_type()];
let output_t = vec![bool_type(), bool_type(), bool_type()];
let mut h = FunctionBuilder::new("preset_bits", Signature::new(input_t, output_t)).unwrap();

let [b0] = h.input_wires_arr();
let b1 = h.add_load_value(ConstBool::new(false));
let [b_and] = h
.add_dataflow_op(BoolOp::and, [b0, b1])
.unwrap()
.outputs_arr();

let mut hugr = h.finish_hugr_with_outputs([b0, b_and, b0]).unwrap();

// A preset register for the first qubit output
hugr.set_metadata::<metadata::BitRegisters>(
hugr.entrypoint(),
vec![ElementId(String::from("b"), vec![1])]
.into_iter()
.map(register::Bit::from)
.collect_vec(),
);

hugr
}

/// A simple circuit with some input parameters
#[fixture]
fn circ_parameterized() -> Hugr {
Expand Down Expand Up @@ -961,6 +990,7 @@ fn encoded_circuit_attributes(circ_measure_ancilla: Hugr) {
#[rstest]
#[case::meas_ancilla(circ_measure_ancilla(), CircuitRoundtripTestConfig::Default)]
#[case::preset_qubits(circ_preset_qubits(), CircuitRoundtripTestConfig::Default)]
#[case::preset_bits(circ_preset_bits(), CircuitRoundtripTestConfig::Default)]
#[case::preset_parameterized(circ_parameterized(), CircuitRoundtripTestConfig::Default)]
// TODO: Should pass once CircBox encoding of DFGs is re-enabled.
#[should_panic(expected = "Cannot encode subgraphs with nested structure")]
Expand Down Expand Up @@ -1068,6 +1098,7 @@ fn fail_on_modified_hugr(circ_tk1_ops: Hugr) {
#[rstest]
#[case::meas_ancilla(circ_measure_ancilla(), 1, CircuitRoundtripTestConfig::Default)]
#[case::preset_qubits(circ_preset_qubits(), 1, CircuitRoundtripTestConfig::Default)]
#[case::preset_bits(circ_preset_bits(), 1, CircuitRoundtripTestConfig::Default)]
#[case::preset_parameterized(circ_parameterized(), 1, CircuitRoundtripTestConfig::Default)]
#[case::nested_dfgs(circ_nested_dfgs(), 2, CircuitRoundtripTestConfig::Default)]
#[case::flat_opaque(circ_tk1_ops(), 1, CircuitRoundtripTestConfig::Default)]
Expand Down
Loading