Skip to content

Commit 51e257d

Browse files
authored
bump default qubit normalization tolerance (#9014)
**Context:** User discovered that our sampling tolerance for normalization was too small. **Description of the Change:** Bumps the tolerance to 1e-6. **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** Fixes #9000 [sc-109916]
1 parent f0e8efd commit 51e257d

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@
278278

279279
<h3>Bug fixes 🐛</h3>
280280

281+
* Bumps the tolerance used in determining whether the norm of the probabilities is sufficiently close to
282+
1 in Default Qubit.
283+
[(#9014)](https://github.com/PennyLaneAI/pennylane/pull/9014)
284+
281285
* Removes automatic unpacking of inner product resources in the resource representation of
282286
:class:`~.ops.op_math.Prod` for the graph-based decomposition system. This resolves a bug that
283287
prevents decompositions in this system from using nested operator products while reporting their

pennylane/devices/qubit/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def _sample_probs_numpy(probs, shots, num_wires, is_state_batched, rng):
513513
rng = np.random.default_rng(rng)
514514
norm = qml.math.sum(probs, axis=-1)
515515
norm_err = qml.math.abs(norm - 1.0)
516-
cutoff = 1e-07
516+
cutoff = 1e-06
517517

518518
norm_err = norm_err if is_state_batched else norm_err[..., np.newaxis]
519519
if qml.math.any(norm_err > cutoff):

tests/devices/qubit/test_sampling.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def test_nan_shadow_expval(self, H, interface, shots):
729729
batched_state_not_normalized = np.stack(
730730
[
731731
np.array([[0, 0], [0, 1]]),
732-
np.array([[1.0000004, 0], [1, 0]]) / np.sqrt(2),
732+
np.array([[1.000004, 0], [1, 0]]) / np.sqrt(2),
733733
np.array([[1, 1], [1, 0.9999995]]) / 2,
734734
]
735735
)
@@ -1344,14 +1344,14 @@ def test_batched_sampling(self, seed):
13441344

13451345
def test_cutoff_edge_case_failure(self, seed):
13461346
"""Test sampling with probabilities just outside the cutoff."""
1347-
cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs
1347+
cutoff = 1e-6 # Assuming this is the cutoff used in sample_probs
13481348
probs = np.array([0.5, 0.5 - 2 * cutoff])
13491349
with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"):
13501350
sample_probs(probs, shots=1000, num_wires=1, is_state_batched=False, rng=seed)
13511351

13521352
def test_batched_cutoff_edge_case_failure(self, seed):
13531353
"""Test sampling with probabilities just outside the cutoff."""
1354-
cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs
1354+
cutoff = 1e-6 # Assuming this is the cutoff used in sample_probs
13551355
probs = np.array(
13561356
[
13571357
[0.5, 0.5 - 2 * cutoff],
@@ -1360,3 +1360,38 @@ def test_batched_cutoff_edge_case_failure(self, seed):
13601360
)
13611361
with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"):
13621362
sample_probs(probs, shots=1000, num_wires=1, is_state_batched=True, rng=seed)
1363+
1364+
@pytest.mark.jax
1365+
def test_no_error_with_jax_32_bit_precision(self):
1366+
"""Tests that a bug reported where jax 32 bit parameters caused a probability norm further from 1 then the initial cutoff.
1367+
1368+
See https://github.com/PennyLaneAI/pennylane/issues/9000 for the report.
1369+
"""
1370+
1371+
import jax # pylint: disable=import-outside-toplevel
1372+
1373+
feature_count = 2
1374+
1375+
key = jax.random.key(123)
1376+
key_inputs, key_params = jax.random.split(key)
1377+
1378+
inputs = jax.random.uniform(key_inputs, shape=(1450, 2))
1379+
params = jax.random.uniform(key_params, shape=(2, 3))
1380+
1381+
device = qml.device("default.qubit")
1382+
1383+
@qml.qnode(device)
1384+
def circuit(inputs, weights):
1385+
for i in range(feature_count):
1386+
qml.RY(inputs[:, i], wires=i)
1387+
for i in range(feature_count):
1388+
qml.RX(weights[i, 3], wires=i)
1389+
qml.RY(weights[i, 4], wires=i)
1390+
qml.RX(weights[i, 5], wires=i)
1391+
qml.CNOT(wires=[i, (i + 1) % feature_count])
1392+
for i in range(1, feature_count - 1):
1393+
qml.CNOT(wires=[i, (i + 1)])
1394+
return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(feature_count)]))
1395+
1396+
# just testing it runs without error
1397+
_ = qml.set_shots(circuit, 8)(inputs, params)

0 commit comments

Comments
 (0)