diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 498230f4c58..364e74b0f21 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -249,6 +249,9 @@
Internal changes ⚙️
+* `qml.counts` of mid circuit measurements can now be captured into jaxpr.
+ [(#9022)](https://github.com/PennyLaneAI/pennylane/pull/9022)
+
* Pass-by-pass specs now use ``BoundTransform.tape_transform`` rather than the deprecated ``BoundTransform.transform``.
Additionally, several internal comments have been updated to bring specs in line with the new ``CompilePipeline`` class.
[(#9012)](https://github.com/PennyLaneAI/pennylane/pull/9012)
diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py
index 665b26e2ba7..ebf5e35bc37 100644
--- a/pennylane/measurements/counts.py
+++ b/pennylane/measurements/counts.py
@@ -339,6 +339,34 @@ def _abstract_eval(*args, has_eigvals=False, all_outcomes=False):
return keys, values
+# pylint: disable=protected-access, unused-argument
+if CountsMP._mcm_primitive is not None:
+
+ CountsMP._mcm_primitive.multiple_results = True
+
+ @CountsMP._mcm_primitive.def_impl
+ def _mcm_impl(*args, **kwargs):
+ raise NotImplementedError("Counts has no execution implementation with program capture.")
+
+ def _mcm_keys_eval(n_wires, has_eigvals=False, shots=None, num_device_wires=0):
+ if shots is None:
+ raise ValueError("finite shots are required to use CountsMP")
+ return (2**n_wires,), int
+
+ def _mcm_values_eval(n_wires, has_eigvals=False, shots=None, num_device_wires=0):
+ if shots is None:
+ raise ValueError("finite shots are required to use CountsMP")
+ return (2**n_wires,), int
+
+ abstract_mp = _get_abstract_measurement()
+
+ @CountsMP._mcm_primitive.def_abstract_eval
+ def _mcm_abstract_eval(*mcms, single_mcm, all_outcomes=False):
+ keys = abstract_mp(_mcm_keys_eval, n_wires=len(mcms), has_eigvals=False)
+ values = abstract_mp(_mcm_values_eval, n_wires=len(mcms), has_eigvals=False)
+ return keys, values
+
+
def counts(
op=None,
wires=None,
diff --git a/pennylane/measurements/measurements.py b/pennylane/measurements/measurements.py
index ffb49ebb197..a05a99570ba 100644
--- a/pennylane/measurements/measurements.py
+++ b/pennylane/measurements/measurements.py
@@ -108,8 +108,10 @@ def _primitive_bind_call(cls, obs=None, wires=None, eigvals=None, id=None, **kwa
if isinstance(getattr(obs, "aval", None), _get_abstract_operator()):
return cls._obs_primitive.bind(obs, **kwargs)
if isinstance(obs, (list, tuple)):
- return cls._mcm_primitive.bind(*obs, single_mcm=False, **kwargs) # iterable of mcms
- return cls._mcm_primitive.bind(obs, single_mcm=True, **kwargs) # single mcm
+ out = cls._mcm_primitive.bind(*obs, single_mcm=False, **kwargs) # iterable of mcms
+ return tuple(out) if isinstance(out, list) else out
+ out = cls._mcm_primitive.bind(obs, single_mcm=True, **kwargs) # single mcm
+ return tuple(out) if isinstance(out, list) else out
# pylint: disable=unused-argument
@classmethod
diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py
index 95c61265a6c..5151727f899 100644
--- a/tests/capture/test_measurements_capture.py
+++ b/tests/capture/test_measurements_capture.py
@@ -109,6 +109,15 @@ def test_counts_no_implementation(self):
):
qml.counts()
+ def test_counts_no_implementation_mcm(self):
+ """Test that counts of an mcm can't be measured and raises a NotImplementedError."""
+
+ with pytest.raises(
+ NotImplementedError,
+ match=r"Counts has no execution implementation with program capture.",
+ ):
+ qml.counts(2)
+
def test_warning_about_all_outcomes(self):
"""Test a warning is raised about all_outcomes=False"""
@@ -148,6 +157,40 @@ def f():
with pytest.raises(ValueError, match="finite shots are required"):
jaxpr.outvars[1].aval.abstract_eval(num_device_wires=0, shots=None)
+ @pytest.mark.parametrize("num_mcms", [1, 2])
+ def test_counts_mcm_capture_jaxpr(self, num_mcms):
+ """Test that counts can be captured into jaxpr."""
+
+ def f():
+ ms = [qml.measure(0) for _ in range(num_mcms)]
+ return qml.counts(ms)
+
+ jaxpr = jax.make_jaxpr(f)()
+ jaxpr = jaxpr.jaxpr
+
+ assert len(jaxpr.outvars) == 2
+
+ assert jaxpr.eqns[-1].primitive == CountsMP._mcm_primitive
+ assert len(jaxpr.eqns[0].invars) == 1
+
+ keys = jaxpr.outvars[0].aval
+ assert isinstance(keys, AbstractMeasurement)
+ keys_shape = keys.abstract_eval(num_device_wires=0, shots=50)
+ assert keys_shape[0] == (2**num_mcms,)
+ assert keys_shape[1] == int
+
+ with pytest.raises(ValueError, match="finite shots are required"):
+ keys.abstract_eval(num_device_wires=0, shots=None)
+
+ values = jaxpr.outvars[1].aval
+ assert isinstance(values, AbstractMeasurement)
+ values_shape = values.abstract_eval(num_device_wires=0, shots=50)
+ assert values_shape[0] == (2**num_mcms,)
+ assert values_shape[1] == int
+
+ with pytest.raises(ValueError, match="finite shots are required"):
+ values.abstract_eval(num_device_wires=0, shots=None)
+
def test_counts_capture_jaxpr_all_wires(self):
"""Test that counts can be captured into jaxpr."""
@@ -203,6 +246,32 @@ def c():
jaxpr = jax.make_jaxpr(w)().jaxpr
assert len(jaxpr.outvars) == 3
+ def test_qnode_integration_mcms(self):
+ """Test that counts can integrate with capturing a qnode."""
+
+ def w():
+ @qml.qnode(qml.device("default.qubit", wires=2), shots=10)
+ def c():
+ m0 = qml.measure(0)
+ return qml.counts(m0), qml.sample()
+
+ r = c()
+ assert isinstance(r, tuple)
+ assert len(r) == 2
+ assert isinstance(r[0], tuple)
+ assert len(r[0]) == 2
+ for i in (0, 1):
+ assert r[0][i].shape == (2,)
+ assert r[0][i].dtype == jax.numpy.int64
+
+ assert r[1].shape == (10, 2)
+ assert r[1].dtype == jax.numpy.int64
+
+ return r
+
+ jaxpr = jax.make_jaxpr(w)().jaxpr
+ assert len(jaxpr.outvars) == 3
+
def test_primitive_none_behavior():
"""Test that if the obs primitive is None, the measurement can still