Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/qpy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ mod value;
pub fn qpy(module: &Bound<PyModule>) -> PyResult<()> {
module.add_function(wrap_pyfunction!(circuit_writer::py_write_circuit, module)?)?;
module.add_function(wrap_pyfunction!(circuit_reader::py_read_circuit, module)?)?;
module.add_function(wrap_pyfunction!(value::py_write_values, module)?)?;
module.add_function(wrap_pyfunction!(value::py_read_values, module)?)?;
Ok(())
}

Expand Down
117 changes: 113 additions & 4 deletions crates/qpy/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ use binrw::{BinRead, BinWrite, Endian, binrw};
use hashbrown::HashMap;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyAny;
use pyo3::types::{PyAny, PyBytes};

use qiskit_circuit::bit::{ClassicalRegister, ShareableClbit};
use qiskit_circuit::circuit_data::CircuitData;
use qiskit_circuit::classical::expr::{Expr, Stretch, Var};
use qiskit_circuit::classical::types::Type;
use qiskit_circuit::converters::QuantumCircuitData;
use qiskit_circuit::duration::Duration;
use qiskit_circuit::operations::{ForCollection, OperationRef, PyRange};
use qiskit_circuit::operations::{ForCollection, OperationRef, Param, PyRange};
use qiskit_circuit::packed_instruction::PackedOperation;
use qiskit_circuit::parameter::parameter_expression::ParameterExpression;
use qiskit_circuit::parameter::symbol_expr::Symbol;
Expand All @@ -42,7 +42,8 @@ use crate::params::{
unpack_parameter_vector, unpack_symbol,
};
use crate::py_methods::{
py_deserialize_numpy_object, py_pack_modifier, py_serialize_numpy_object, py_unpack_modifier,
py_convert_from_generic_value, py_convert_to_generic_value, py_deserialize_numpy_object,
py_pack_modifier, py_serialize_numpy_object, py_unpack_modifier,
};
use crate::{QpyError, UnsupportedFeatureForVersion};

Expand All @@ -52,7 +53,7 @@ use std::fmt::Debug;
use std::io::Cursor;
use uuid::Uuid;

pub const QPY_VERSION: u32 = 15;
pub const QPY_VERSION: u32 = 17;

// Standard char representation of register types: 'q' qreg, 'c' for creg
#[binrw]
Expand Down Expand Up @@ -112,6 +113,23 @@ pub struct QPYWriteData<'a> {
pub annotation_handler: AnnotationHandler<'a>,
}

impl<'a> QPYWriteData<'a> {
/// Create a default `QPYWriteData` with the given circuit and annotation factories,
/// using the current QPY version and an empty standalone-var index map.
pub fn default(
circuit_data: &'a CircuitData,
version: u32,
annotation_factories: &'a Bound<'a, pyo3::types::PyDict>,
) -> Self {
QPYWriteData {
circuit_data,
version,
standalone_var_indices: HashMap::new(),
annotation_handler: AnnotationHandler::new(annotation_factories),
}
}
}

// Data that is needed globally while reading the circuit
#[derive(Debug)]
pub struct QPYReadData<'a> {
Expand All @@ -124,6 +142,25 @@ pub struct QPYReadData<'a> {
pub annotation_handler: AnnotationHandler<'a>,
}

impl<'a> QPYReadData<'a> {
/// Create a default `QPYReadData` with the given circuit, annotation factories and symengine flag,
/// using the current QPY version and empty maps for vars, stretches and vectors.
pub fn default(
circuit_data: &'a mut CircuitData,
annotation_factories: &'a Bound<'a, pyo3::types::PyDict>,
) -> Self {
QPYReadData {
circuit_data,
version: QPY_VERSION,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want the read side to also be able to set the qpy version?
\

Copy link
Copy Markdown
Contributor Author

@gadial gadial Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We must, this is another oversight by me, and it's a real shame because I actually thought about this but assumed the version is packaged in the qpy, since I just handled that in #15749 :-(

done in 645d0b6

use_symengine: false,
standalone_vars: HashMap::new(),
standalone_stretches: HashMap::new(),
vectors: HashMap::new(),
annotation_handler: AnnotationHandler::new(annotation_factories),
}
}
}

// this is how tags for various value types are encoded in a QPY file
#[binrw]
#[brw(repr = u8)]
Expand Down Expand Up @@ -868,3 +905,75 @@ pub(crate) fn load_param_register_value(
)))
}
}

/// Write a list of QPY-serializable values to a file object.
///
/// Args:
/// file_obj: The file object to write to.
/// values: The list of values to serialize and write.
///
#[pyfunction]
#[pyo3(name = "write_values")]
Comment thread
ihincks marked this conversation as resolved.
#[pyo3(signature = (file_obj, values, version=None))]
pub(crate) fn py_write_values(
py: Python,
file_obj: &Bound<pyo3::types::PyAny>,
values: &Bound<pyo3::types::PyAny>,
version: Option<u32>,
) -> PyResult<usize> {
let version = version.unwrap_or(QPY_VERSION);
let dummy_circuit = CircuitData::new(None, None, Param::Float(0.0))?;
let empty_dict = pyo3::types::PyDict::new(py);
let qpy_data = QPYWriteData::default(&dummy_circuit, version, &empty_dict);

let mut elements = Vec::new();
for item in values.try_iter()? {
let generic_value = py_convert_to_generic_value(&(item?))?;
let (type_key, data) = serialize_generic_value(&generic_value, &qpy_data)?;
elements.push(GenericDataPack { type_key, data });
}

let sequence_pack = GenericDataSequencePack { elements };
let serialized = serialize(&sequence_pack);
file_obj.call_method1("write", (PyBytes::new(py, &serialized),))?;
Ok(serialized.len())
}

/// Read a list of QPY-serializable values from a file object.
///
/// Args:
///
/// file_obj: The file object to read from. The file's cursor will be advanced by the number of bytes read.
/// Returns:
/// A list of deserialized values read from the file.
#[pyfunction]
#[pyo3(name = "read_values")]
#[pyo3(signature = (file_obj))]
pub(crate) fn py_read_values(
py: Python,
file_obj: &Bound<pyo3::types::PyAny>,
) -> PyResult<Py<pyo3::types::PyAny>> {
use pyo3::types::{PyBytes, PyList};
use qiskit_circuit::circuit_data::CircuitData;
use qiskit_circuit::operations::Param;

let pos = file_obj.call_method0("tell")?.extract::<usize>()?;
let bytes_obj = file_obj.call_method0("read")?;
let raw_bytes: &[u8] = bytes_obj.cast::<PyBytes>()?.as_bytes();

let (sequence_pack, bytes_read) = deserialize::<GenericDataSequencePack>(raw_bytes)?;

let mut dummy_circuit = CircuitData::new(None, None, Param::Float(0.0))?;
let empty_dict = pyo3::types::PyDict::new(py);
let mut qpy_data = QPYReadData::default(&mut dummy_circuit, &empty_dict);

let mut result_list = Vec::with_capacity(sequence_pack.elements.len());
for data_pack in &sequence_pack.elements {
let generic_value = unpack_generic_value(data_pack, &mut qpy_data)?;
let py_obj = py_convert_from_generic_value(&generic_value)?;
result_list.push(py_obj);
}

file_obj.call_method1("seek", (pos + bytes_read,))?;
Ok(PyList::new(py, result_list)?.into_any().unbind())
}
134 changes: 133 additions & 1 deletion test/python/qpy/test_serialize_value_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import io
from test import QiskitTestCase
from qiskit.circuit import Parameter, QuantumCircuit
from qiskit.circuit import Parameter, ParameterVector, QuantumCircuit
from qiskit import qpy
from qiskit.quantum_info import SparseObservable
from qiskit.quantum_info.operators import SparsePauliOp
Expand Down Expand Up @@ -106,3 +106,135 @@ def test_pauli_evolution_operator_list(self):
qc_from_qpy = qpy.load(container)[0]

self.assertEqual(circuit, qc_from_qpy)


class TestWriteReadValueList(QiskitTestCase):
"""Tests for the write_values / read_values Rust QPY functions."""

def _roundtrip(self, values):
"""Helper: write then read back a list of values."""
from qiskit._accelerate import qpy as _qpy

buf = io.BytesIO()
_qpy.write_values(buf, values)
buf.seek(0)
return _qpy.read_values(buf)

# ------------------------------------------------------------------
# Parameter / ParameterExpression
# ------------------------------------------------------------------

def test_parameter_roundtrip(self):
"""A bare Parameter survives a write/read cycle."""
a = Parameter("a")
result = self._roundtrip([a])
self.assertEqual(len(result), 1)
self.assertEqual(result[0], a)

def test_parameter_vector_element_roundtrip(self):
"""A ParameterVector element survives a write/read cycle."""
v = ParameterVector("v", 3)
result = self._roundtrip([v[0], v[2]])
self.assertEqual(len(result), 2)
self.assertEqual(result[0], v[0])
self.assertEqual(result[1], v[2])

def test_parameter_expression_roundtrip(self):
"""A compound ParameterExpression survives a write/read cycle."""
a = Parameter("a")
b = Parameter("b")
expr = a * 2 + b / 3

result = self._roundtrip([expr])
self.assertEqual(len(result), 1)
self.assertEqual(result[0], expr)

def test_multiple_parameter_expressions(self):
"""Multiple ParameterExpressions in one list all round-trip correctly."""
a = Parameter("a")
b = Parameter("b")
exprs = [a, b, a + b, a * b - 1.5, b**a]

result = self._roundtrip(exprs)
self.assertEqual(len(result), len(exprs))
for original, recovered in zip(exprs, result):
self.assertEqual(original, recovered)

def test_float_roundtrip(self):
"""Float values survive a write/read cycle."""
values = [0.0, 1.5, -3.14, float("inf")]
result = self._roundtrip(values)
self.assertEqual(len(result), len(values))
for original, recovered in zip(values, result):
self.assertEqual(original, recovered)

def test_int_roundtrip(self):
"""Integer values survive a write/read cycle."""
values = [0, 1, -42, 2**31]
result = self._roundtrip(values)
self.assertEqual(len(result), len(values))
for original, recovered in zip(values, result):
self.assertEqual(original, recovered)

def test_complex_roundtrip(self):
"""Complex values survive a write/read cycle."""
values = [1 + 2j, -3.5 + 0j, 0 - 1j]
result = self._roundtrip(values)
self.assertEqual(len(result), len(values))
for original, recovered in zip(values, result):
self.assertEqual(original, recovered)

def test_range_roundtrip(self):
"""A Python range survives a write/read cycle."""
r = range(2, 10, 3)
result = self._roundtrip([r])
self.assertEqual(len(result), 1)
self.assertEqual(list(result[0]), list(r))

def test_tuple_of_mixed_scalars(self):
"""A tuple of mixed scalar types survives a write/read cycle."""
t = (1, 2.5, 3 + 4j)
result = self._roundtrip([t])
self.assertEqual(len(result), 1)
recovered = result[0]
self.assertEqual(recovered[0], 1)
self.assertAlmostEqual(recovered[1], 2.5)
self.assertEqual(recovered[2], 3 + 4j)

def test_circuit_roundtrip(self):
"""A QuantumCircuit survives a write/read cycle."""
qc = QuantumCircuit(2)
qc.h(0)
qc.cx(0, 1)

result = self._roundtrip([qc])
self.assertEqual(len(result), 1)
self.assertEqual(result[0], qc)

def test_parameterized_circuit_roundtrip(self):
"""A parameterized QuantumCircuit survives a write/read cycle."""
theta = Parameter("theta")
qc = QuantumCircuit(1)
qc.rz(theta, 0)

result = self._roundtrip([qc])
self.assertEqual(len(result), 1)
self.assertEqual(result[0], qc)

def test_mixed_value_list(self):
"""A heterogeneous list of values all round-trip correctly."""
a = Parameter("a")
qc = QuantumCircuit(1)
qc.x(0)
values = [1, 2.5, 1 + 2j, a, a * 3, range(5), (0, 1.0), qc]

result = self._roundtrip(values)
self.assertEqual(len(result), len(values))
self.assertEqual(result[0], 1)
self.assertAlmostEqual(result[1], 2.5)
self.assertEqual(result[2], 1 + 2j)
self.assertEqual(result[3], a)
self.assertEqual(result[4], a * 3)
self.assertEqual(result[5], range(5))
self.assertEqual(result[6], (0, 1.0))
self.assertEqual(result[7], qc)