diff --git a/tket/src/serialize/pytket/decoder.rs b/tket/src/serialize/pytket/decoder.rs index 0c4b1d332..2c0e92572 100644 --- a/tket/src/serialize/pytket/decoder.rs +++ b/tket/src/serialize/pytket/decoder.rs @@ -312,33 +312,36 @@ impl<'h> PytketDecoderContext<'h> { encoded_info: Option<&EncodedCircuitInfo>, ) -> Result { // Order the final wires according to the serial circuit register order. - let known_qubits = self - .wire_tracker - .known_pytket_qubits() - .cloned() - .collect_vec(); - let known_bits = self.wire_tracker.known_pytket_bits().cloned().collect_vec(); + let mut known_qubits: IndexSet = + self.wire_tracker.known_pytket_qubits().cloned().collect(); + let mut known_bits: IndexSet = + self.wire_tracker.known_pytket_bits().cloned().collect(); // Qubits and bits appearing at the output. - let mut qubits: IndexSet = IndexSet::new(); - let mut bits: IndexSet = IndexSet::new(); + let mut qubits: Vec = Vec::new(); + let mut bits: Vec = Vec::new(); if let Some(encoded_info) = encoded_info { for qubit in encoded_info.output_qubits.iter() { let id = self.wire_tracker.tracked_qubit_for_register(qubit)?; - qubits.insert(id.clone()); + qubits.push(id.clone()); } for bit in encoded_info.output_bits.iter() { let id = self.wire_tracker.tracked_bit_for_register(bit)?; - bits.insert(id.clone()); + bits.push(id.clone()); } } - // Add any additional qubits or bits we have seen, without modifying the - // order of the qubits already there. + + // Add any additional qubits or bits we have seen but haven't added to the list. + for q in &qubits { + known_qubits.shift_remove(q); + } + for b in &bits { + known_bits.shift_remove(b); + } qubits.extend(known_qubits); bits.extend(known_bits); - let qubits: Vec = Vec::from_iter(qubits); - let bits: Vec = Vec::from_iter(bits); + let mut qubits_slice: &[TrackedQubit] = &qubits; let mut bits_slice: &[TrackedBit] = &bits; @@ -623,11 +626,6 @@ impl<'h> PytketDecoderContext<'h> { /// and pytket parameters. Registers the node's output wires in the wire /// tracker. /// - /// The qubits registers in `wires` are reused between the operation inputs - /// and outputs. Bit registers, on the other hand, are not reused. We use - /// the first registers in `wires` for the bit inputs and the remaining - /// registers for the outputs. - /// /// The input wire types must match the operation's input signature, no type /// conversion is performed. /// @@ -741,9 +739,6 @@ impl<'h> PytketDecoderContext<'h> { .hugr_mut() .connect(wire.node(), wire.source(), node, input_idx); } - input_bits.iter().take(op_input_count.bits).for_each(|b| { - self.wire_tracker.mark_bit_outdated(b.clone()); - }); // Register the output wires. let output_qubits = output_qubits.iter().take(op_output_count.qubits).cloned(); @@ -814,9 +809,8 @@ impl<'h> PytketDecoderContext<'h> { /// Given a new node in the HUGR, register all of its output wires in the /// tracker. /// - /// Consumes the bits and qubits in order. Any unused bits and qubits are - /// marked as outdated, as they are assumed to have been consumed in the - /// inputs. + /// Consumes the bits and qubits in order. Any unused qubits are marked as + /// outdated, as they are assumed to have been consumed in the inputs. pub fn register_node_outputs( &mut self, node: Node, @@ -861,13 +855,10 @@ impl<'h> PytketDecoderContext<'h> { .track_wire(wire, Arc::new(ty.clone()), wire_qubits, wire_bits)?; } - // Mark any unused qubits and bits as outdated. + // Mark any unused qubits as outdated. qubits.for_each(|q| { self.wire_tracker.mark_qubit_outdated(q); }); - bits.for_each(|b| { - self.wire_tracker.mark_bit_outdated(b); - }); Ok(()) } diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index 54ab448fa..7988e5d4a 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -686,14 +686,16 @@ impl WireTracker { }; // List candidate wires that contain the qubits and bits we need. - let qubit_candidates = qubit_args - .first() - .into_iter() - .flat_map(|qb| self.qubit_wires(qb)); - let bit_candidates = bit_args - .first() - .into_iter() - .flat_map(|bit| self.bit_wires(bit)); + let qubit_candidates = if reg_count.qubits > 0 && !qubit_args.is_empty() { + itertools::Either::Left(self.qubit_wires(&qubit_args[0])) + } else { + itertools::Either::Right(std::iter::empty()) + }; + let bit_candidates = if reg_count.bits > 0 && !bit_args.is_empty() { + itertools::Either::Left(self.bit_wires(&bit_args[0])) + } else { + itertools::Either::Right(std::iter::empty()) + }; let candidates = qubit_candidates.chain(bit_candidates).collect_vec(); // The bits and qubits we expect the wire to contain. @@ -766,13 +768,13 @@ impl WireTracker { // Convert the wire type, if needed. let wire_data = &self.wires[&wire]; - let new_wire = config.transform_typed_value(wire, wire_data.ty(), ty, builder)?; + let found_wire_type = wire_data.ty(); + let new_wire = config.transform_typed_value(wire, found_wire_type, ty, builder)?; if wire == new_wire { Ok(FoundWire::Register(self.wires[&wire].clone())) } else { - let ty: Arc = wire_data.ty.clone(); - self.track_wire(new_wire, ty, wire_qubits, wire_bits)?; + self.track_wire(new_wire, Arc::new(ty.clone()), wire_qubits, wire_bits)?; self.mark_wire_outdated(wire); Ok(FoundWire::Register(self.wires[&new_wire].clone())) } diff --git a/tket/src/serialize/pytket/extension/bool.rs b/tket/src/serialize/pytket/extension/bool.rs index fe66d9828..c7d53798d 100644 --- a/tket/src/serialize/pytket/extension/bool.rs +++ b/tket/src/serialize/pytket/extension/bool.rs @@ -59,7 +59,7 @@ impl PytketEmitter for BoolEmitter { // variable inputs. If new [`BoolOp`]s are added that do not follow // this, the following code will need to be adjusted. let bit_count = (num_inputs + num_outputs) as usize; - let output_bits = (0..num_outputs).collect_vec(); + let output_bits = (num_inputs..(num_inputs + num_outputs)).collect_vec(); let mut expression = ClOperator::default(); expression.op = clop; expression.args = (0..num_inputs) @@ -67,7 +67,13 @@ impl PytketEmitter for BoolEmitter { .collect_vec(); let op = make_tk1_classical_expression(bit_count, &output_bits, &[], expression); - encoder.emit_node_command(node, hugr, EmitCommandOptions::new(), move |_| op)?; + encoder.emit_node_command( + node, + hugr, + // Output bits use new registers, so don't reuse any input bits. + EmitCommandOptions::new().reuse_bits(|_| vec![]), + move |_| op, + )?; Ok(EncodeStatus::Success) } diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index d3bbc436c..90bd6c733 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use super::TKETDecode; use crate::TketOp; use crate::extension::TKET1_EXTENSION_ID; -use crate::extension::bool::{BoolOp, bool_type}; +use crate::extension::bool::{BoolOp, ConstBool, bool_type}; use crate::extension::rotation::{ConstRotation, RotationOp, rotation_type}; use crate::extension::sympy::SympyOpDef; use crate::metadata; @@ -287,6 +287,35 @@ fn circ_preset_qubits() -> Hugr { hugr } +/// A simple circuit with some preset input and output bit registers, +/// including multiple outputs for the same register. +#[fixture] +fn circ_preset_bits() -> Hugr { + let input_t = vec![bool_type()]; + let output_t = vec![bool_type(), bool_type(), bool_type()]; + let mut h = FunctionBuilder::new("preset_bits", Signature::new(input_t, output_t)).unwrap(); + + let [b0] = h.input_wires_arr(); + let b1 = h.add_load_value(ConstBool::new(false)); + let [b_and] = h + .add_dataflow_op(BoolOp::and, [b0, b1]) + .unwrap() + .outputs_arr(); + + let mut hugr = h.finish_hugr_with_outputs([b0, b_and, b0]).unwrap(); + + // A preset register for the first qubit output + hugr.set_metadata::( + hugr.entrypoint(), + vec![ElementId(String::from("b"), vec![1])] + .into_iter() + .map(register::Bit::from) + .collect_vec(), + ); + + hugr +} + /// A simple circuit with some input parameters #[fixture] fn circ_parameterized() -> Hugr { @@ -961,6 +990,7 @@ fn encoded_circuit_attributes(circ_measure_ancilla: Hugr) { #[rstest] #[case::meas_ancilla(circ_measure_ancilla(), CircuitRoundtripTestConfig::Default)] #[case::preset_qubits(circ_preset_qubits(), CircuitRoundtripTestConfig::Default)] +#[case::preset_bits(circ_preset_bits(), CircuitRoundtripTestConfig::Default)] #[case::preset_parameterized(circ_parameterized(), CircuitRoundtripTestConfig::Default)] // TODO: Should pass once CircBox encoding of DFGs is re-enabled. #[should_panic(expected = "Cannot encode subgraphs with nested structure")] @@ -1068,6 +1098,7 @@ fn fail_on_modified_hugr(circ_tk1_ops: Hugr) { #[rstest] #[case::meas_ancilla(circ_measure_ancilla(), 1, CircuitRoundtripTestConfig::Default)] #[case::preset_qubits(circ_preset_qubits(), 1, CircuitRoundtripTestConfig::Default)] +#[case::preset_bits(circ_preset_bits(), 1, CircuitRoundtripTestConfig::Default)] #[case::preset_parameterized(circ_parameterized(), 1, CircuitRoundtripTestConfig::Default)] #[case::nested_dfgs(circ_nested_dfgs(), 2, CircuitRoundtripTestConfig::Default)] #[case::flat_opaque(circ_tk1_ops(), 1, CircuitRoundtripTestConfig::Default)]