@@ -729,7 +729,7 @@ def test_nan_shadow_expval(self, H, interface, shots):
729729batched_state_not_normalized = np .stack (
730730 [
731731 np .array ([[0 , 0 ], [0 , 1 ]]),
732- np .array ([[1.0000004 , 0 ], [1 , 0 ]]) / np .sqrt (2 ),
732+ np .array ([[1.000004 , 0 ], [1 , 0 ]]) / np .sqrt (2 ),
733733 np .array ([[1 , 1 ], [1 , 0.9999995 ]]) / 2 ,
734734 ]
735735)
@@ -1344,14 +1344,14 @@ def test_batched_sampling(self, seed):
13441344
13451345 def test_cutoff_edge_case_failure (self , seed ):
13461346 """Test sampling with probabilities just outside the cutoff."""
1347- cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs
1347+ cutoff = 1e-6 # Assuming this is the cutoff used in sample_probs
13481348 probs = np .array ([0.5 , 0.5 - 2 * cutoff ])
13491349 with pytest .raises (ValueError , match = r"(?i)probabilities do not sum to 1" ):
13501350 sample_probs (probs , shots = 1000 , num_wires = 1 , is_state_batched = False , rng = seed )
13511351
13521352 def test_batched_cutoff_edge_case_failure (self , seed ):
13531353 """Test sampling with probabilities just outside the cutoff."""
1354- cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs
1354+ cutoff = 1e-6 # Assuming this is the cutoff used in sample_probs
13551355 probs = np .array (
13561356 [
13571357 [0.5 , 0.5 - 2 * cutoff ],
@@ -1360,3 +1360,38 @@ def test_batched_cutoff_edge_case_failure(self, seed):
13601360 )
13611361 with pytest .raises (ValueError , match = r"(?i)probabilities do not sum to 1" ):
13621362 sample_probs (probs , shots = 1000 , num_wires = 1 , is_state_batched = True , rng = seed )
1363+
1364+ @pytest .mark .jax
1365+ def test_no_error_with_jax_32_bit_precision (self ):
1366+ """Tests that a bug reported where jax 32 bit parameters caused a probability norm further from 1 then the initial cutoff.
1367+
1368+ See https://github.com/PennyLaneAI/pennylane/issues/9000 for the report.
1369+ """
1370+
1371+ import jax # pylint: disable=import-outside-toplevel
1372+
1373+ feature_count = 2
1374+
1375+ key = jax .random .key (123 )
1376+ key_inputs , key_params = jax .random .split (key )
1377+
1378+ inputs = jax .random .uniform (key_inputs , shape = (1450 , 2 ))
1379+ params = jax .random .uniform (key_params , shape = (2 , 3 ))
1380+
1381+ device = qml .device ("default.qubit" )
1382+
1383+ @qml .qnode (device )
1384+ def circuit (inputs , weights ):
1385+ for i in range (feature_count ):
1386+ qml .RY (inputs [:, i ], wires = i )
1387+ for i in range (feature_count ):
1388+ qml .RX (weights [i , 3 ], wires = i )
1389+ qml .RY (weights [i , 4 ], wires = i )
1390+ qml .RX (weights [i , 5 ], wires = i )
1391+ qml .CNOT (wires = [i , (i + 1 ) % feature_count ])
1392+ for i in range (1 , feature_count - 1 ):
1393+ qml .CNOT (wires = [i , (i + 1 )])
1394+ return qml .expval (qml .sum (* [qml .PauliZ (i ) for i in range (feature_count )]))
1395+
1396+ # just testing it runs without error
1397+ _ = qml .set_shots (circuit , 8 )(inputs , params )
0 commit comments