diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 498230f4c58..364e74b0f21 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -249,6 +249,9 @@

Internal changes ⚙️

+* `qml.counts` of mid circuit measurements can now be captured into jaxpr. + [(#9022)](https://github.com/PennyLaneAI/pennylane/pull/9022) + * Pass-by-pass specs now use ``BoundTransform.tape_transform`` rather than the deprecated ``BoundTransform.transform``. Additionally, several internal comments have been updated to bring specs in line with the new ``CompilePipeline`` class. [(#9012)](https://github.com/PennyLaneAI/pennylane/pull/9012) diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 665b26e2ba7..ebf5e35bc37 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -339,6 +339,34 @@ def _abstract_eval(*args, has_eigvals=False, all_outcomes=False): return keys, values +# pylint: disable=protected-access, unused-argument +if CountsMP._mcm_primitive is not None: + + CountsMP._mcm_primitive.multiple_results = True + + @CountsMP._mcm_primitive.def_impl + def _mcm_impl(*args, **kwargs): + raise NotImplementedError("Counts has no execution implementation with program capture.") + + def _mcm_keys_eval(n_wires, has_eigvals=False, shots=None, num_device_wires=0): + if shots is None: + raise ValueError("finite shots are required to use CountsMP") + return (2**n_wires,), int + + def _mcm_values_eval(n_wires, has_eigvals=False, shots=None, num_device_wires=0): + if shots is None: + raise ValueError("finite shots are required to use CountsMP") + return (2**n_wires,), int + + abstract_mp = _get_abstract_measurement() + + @CountsMP._mcm_primitive.def_abstract_eval + def _mcm_abstract_eval(*mcms, single_mcm, all_outcomes=False): + keys = abstract_mp(_mcm_keys_eval, n_wires=len(mcms), has_eigvals=False) + values = abstract_mp(_mcm_values_eval, n_wires=len(mcms), has_eigvals=False) + return keys, values + + def counts( op=None, wires=None, diff --git a/pennylane/measurements/measurements.py b/pennylane/measurements/measurements.py index ffb49ebb197..a05a99570ba 100644 --- a/pennylane/measurements/measurements.py +++ b/pennylane/measurements/measurements.py @@ -108,8 +108,10 @@ def _primitive_bind_call(cls, obs=None, wires=None, eigvals=None, id=None, **kwa if isinstance(getattr(obs, "aval", None), _get_abstract_operator()): return cls._obs_primitive.bind(obs, **kwargs) if isinstance(obs, (list, tuple)): - return cls._mcm_primitive.bind(*obs, single_mcm=False, **kwargs) # iterable of mcms - return cls._mcm_primitive.bind(obs, single_mcm=True, **kwargs) # single mcm + out = cls._mcm_primitive.bind(*obs, single_mcm=False, **kwargs) # iterable of mcms + return tuple(out) if isinstance(out, list) else out + out = cls._mcm_primitive.bind(obs, single_mcm=True, **kwargs) # single mcm + return tuple(out) if isinstance(out, list) else out # pylint: disable=unused-argument @classmethod diff --git a/tests/capture/test_measurements_capture.py b/tests/capture/test_measurements_capture.py index 95c61265a6c..5151727f899 100644 --- a/tests/capture/test_measurements_capture.py +++ b/tests/capture/test_measurements_capture.py @@ -109,6 +109,15 @@ def test_counts_no_implementation(self): ): qml.counts() + def test_counts_no_implementation_mcm(self): + """Test that counts of an mcm can't be measured and raises a NotImplementedError.""" + + with pytest.raises( + NotImplementedError, + match=r"Counts has no execution implementation with program capture.", + ): + qml.counts(2) + def test_warning_about_all_outcomes(self): """Test a warning is raised about all_outcomes=False""" @@ -148,6 +157,40 @@ def f(): with pytest.raises(ValueError, match="finite shots are required"): jaxpr.outvars[1].aval.abstract_eval(num_device_wires=0, shots=None) + @pytest.mark.parametrize("num_mcms", [1, 2]) + def test_counts_mcm_capture_jaxpr(self, num_mcms): + """Test that counts can be captured into jaxpr.""" + + def f(): + ms = [qml.measure(0) for _ in range(num_mcms)] + return qml.counts(ms) + + jaxpr = jax.make_jaxpr(f)() + jaxpr = jaxpr.jaxpr + + assert len(jaxpr.outvars) == 2 + + assert jaxpr.eqns[-1].primitive == CountsMP._mcm_primitive + assert len(jaxpr.eqns[0].invars) == 1 + + keys = jaxpr.outvars[0].aval + assert isinstance(keys, AbstractMeasurement) + keys_shape = keys.abstract_eval(num_device_wires=0, shots=50) + assert keys_shape[0] == (2**num_mcms,) + assert keys_shape[1] == int + + with pytest.raises(ValueError, match="finite shots are required"): + keys.abstract_eval(num_device_wires=0, shots=None) + + values = jaxpr.outvars[1].aval + assert isinstance(values, AbstractMeasurement) + values_shape = values.abstract_eval(num_device_wires=0, shots=50) + assert values_shape[0] == (2**num_mcms,) + assert values_shape[1] == int + + with pytest.raises(ValueError, match="finite shots are required"): + values.abstract_eval(num_device_wires=0, shots=None) + def test_counts_capture_jaxpr_all_wires(self): """Test that counts can be captured into jaxpr.""" @@ -203,6 +246,32 @@ def c(): jaxpr = jax.make_jaxpr(w)().jaxpr assert len(jaxpr.outvars) == 3 + def test_qnode_integration_mcms(self): + """Test that counts can integrate with capturing a qnode.""" + + def w(): + @qml.qnode(qml.device("default.qubit", wires=2), shots=10) + def c(): + m0 = qml.measure(0) + return qml.counts(m0), qml.sample() + + r = c() + assert isinstance(r, tuple) + assert len(r) == 2 + assert isinstance(r[0], tuple) + assert len(r[0]) == 2 + for i in (0, 1): + assert r[0][i].shape == (2,) + assert r[0][i].dtype == jax.numpy.int64 + + assert r[1].shape == (10, 2) + assert r[1].dtype == jax.numpy.int64 + + return r + + jaxpr = jax.make_jaxpr(w)().jaxpr + assert len(jaxpr.outvars) == 3 + def test_primitive_none_behavior(): """Test that if the obs primitive is None, the measurement can still