Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update parameters to accommodate with RZZ constraints #2126

Draft
wants to merge 31 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7dbfe26
wrote a test for rzz conversion
yaelbh Feb 4, 2025
0aee8a7
switching to pubs
yaelbh Feb 5, 2025
1f19512
a beginning of the change in the pub
yaelbh Feb 6, 2025
4be8666
global phase done
yaelbh Feb 6, 2025
763da6c
fixes
yaelbh Feb 6, 2025
b2b157c
fixed test
yaelbh Feb 6, 2025
e660cb8
handling rz, rx, rzz and removing global phase
yaelbh Feb 6, 2025
95f60a5
update test to ignore global phase
yaelbh Feb 6, 2025
cfa60ef
black
yaelbh Feb 6, 2025
7fd4b34
black
yaelbh Feb 6, 2025
b4d6729
lint
yaelbh Feb 6, 2025
aa29a71
some fixes
yaelbh Feb 6, 2025
b50a77a
bug fix
yaelbh Feb 6, 2025
64683c2
preparing to test many inputs
yaelbh Feb 6, 2025
a86051e
test more cases
yaelbh Feb 6, 2025
b059808
make the test a bit more interesting; still need to implement for dyn…
yaelbh Feb 6, 2025
b29cd64
enhanced test
yaelbh Feb 9, 2025
2334016
beginning of a test of dynamic (still missing testing of qubit indices)
yaelbh Feb 9, 2025
27e432a
skip test of dynamic circuits
yaelbh Feb 9, 2025
337039d
remove debug prints
yaelbh Feb 9, 2025
cb0a21c
black
yaelbh Feb 9, 2025
1c09dc3
lint
yaelbh Feb 9, 2025
7a4aa09
lint
yaelbh Feb 9, 2025
f873e69
Merge branch 'main' into fixrzzpubs
yaelbh Feb 9, 2025
47cad78
lint
yaelbh Feb 9, 2025
ffa3850
mypy
yaelbh Feb 9, 2025
8632e32
accurately copy the circuit's registers
yaelbh Feb 13, 2025
002f79a
lint
yaelbh Feb 13, 2025
6aa30f4
Merge branch 'main' into fixrzzpubs
yaelbh Feb 13, 2025
437e23a
empty commit to rerun CI
yaelbh Feb 17, 2025
aeb7a90
Merge branch 'main' into fixrzzpubs
yaelbh Feb 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 149 additions & 3 deletions qiskit_ibm_runtime/transpiler/passes/basis/fold_rzz_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@

"""Pass to wrap Rzz gate angle in calibrated range of 0-pi/2."""

from typing import Tuple
from typing import Tuple, Union
from math import pi
from operator import mod
from itertools import chain

from qiskit.converters import dag_to_circuit, circuit_to_dag
from qiskit.circuit.library.standard_gates import RZZGate, RZGate, XGate, GlobalPhaseGate
from qiskit.circuit.parameterexpression import ParameterExpression
from qiskit.circuit import CircuitInstruction, Parameter, ParameterExpression
from qiskit.circuit.library.standard_gates import RZZGate, RZGate, XGate, GlobalPhaseGate, RXGate
from qiskit.circuit import Qubit, ControlFlowOp
from qiskit.dagcircuit import DAGCircuit
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.primitives.containers.estimator_pub import EstimatorPub, EstimatorPubLike
from qiskit.primitives.containers.sampler_pub import SamplerPub, SamplerPubLike

import numpy as np

Expand Down Expand Up @@ -244,3 +248,145 @@ def _quad4(angle: float, qubits: Tuple[Qubit, ...]) -> DAGCircuit:
check=False,
)
return new_dag


def convert_to_rzz_valid_pub(
program_id: str, pub: Union[SamplerPubLike, EstimatorPubLike]
) -> Union[SamplerPub, EstimatorPub]:
"""Return a pub which is compatible with Rzz constraints"""
if program_id == "sampler":
pub = SamplerPub.coerce(pub)
elif program_id == "estimator":
pub = EstimatorPub.coerce(pub)
else:
raise ValueError(f"Unknown program id {program_id}")

val_data = pub.parameter_values.data
pub_params = np.array(list(chain.from_iterable(val_data)))
# first axis will be over flattened shape, second axis over circuit parameters
arr = pub.parameter_values.ravel().as_array()

new_circ = pub.circuit.copy_empty_like()
new_data = []
rzz_count = 0

for instruction in pub.circuit.data:
operation = instruction.operation
if operation.name != "rzz" or not isinstance(
(param_exp := instruction.operation.params[0]), ParameterExpression
):
new_data.append(instruction)
continue

param_names = [param.name for param in param_exp.parameters]

col_indices = [np.where(pub_params == param_name)[0][0] for param_name in param_names]
# col_indices is the indices of columns in the parameter value array that have to be checked

# project only to the parameters that have to be checked
projected_arr = arr[:, col_indices]
num_param_sets = len(projected_arr)

rz_angles = np.zeros(num_param_sets)
rx_angles = np.zeros(num_param_sets)

for idx, row in enumerate(projected_arr):
angle = float(param_exp.bind(dict(zip(param_exp.parameters, row))))

if (angle + pi / 2) % (2 * pi) >= pi:
rz_angles[idx] = pi
else:
rz_angles[idx] = 0

if angle % pi >= pi / 2:
rx_angles[idx] = pi
else:
rx_angles[idx] = 0

rzz_count += 1
param_prefix = f"rzz_{rzz_count}_"
qubits = instruction.qubits

is_rz = False
if any(not np.isclose(rz_angle, 0) for rz_angle in rz_angles):
is_rz = True
if all(np.isclose(rz_angle, pi) for rz_angle in rz_angles):
new_data.append(
CircuitInstruction(
RZGate(pi),
(qubits[0],),
)
)
new_data.append(
CircuitInstruction(
RZGate(pi),
(qubits[1],),
)
)
else:
param_rz = Parameter(f"{param_prefix}rz")
new_data.append(
CircuitInstruction(
RZGate(param_rz),
(qubits[0],),
)
)
new_data.append(
CircuitInstruction(
RZGate(param_rz),
(qubits[1],),
)
)
val_data[f"{param_prefix}rz"] = rz_angles

is_rx = False
is_x = False
if any(not np.isclose(rx_angle, 0) for rx_angle in rx_angles):
is_rx = True
if all(np.isclose(rx_angle, pi) for rx_angle in rx_angles):
is_x = True
new_data.append(
CircuitInstruction(
XGate(),
(qubits[0],),
)
)
else:
is_x = False
param_rx = Parameter(f"{param_prefix}rx")
new_data.append(
CircuitInstruction(
RXGate(param_rx),
(qubits[0],),
)
)
val_data[f"{param_prefix}rx"] = rx_angles

if is_rz or is_rx:
rzz_angle = pi / 2 - (param_exp._apply_operation(mod, pi) - pi / 2).abs()
new_data.append(CircuitInstruction(RZZGate(rzz_angle), qubits))
else:
new_data.append(instruction)

if is_rx:
if is_x:
new_data.append(
CircuitInstruction(
XGate(),
(qubits[0],),
)
)
else:
new_data.append(
CircuitInstruction(
RXGate(param_rx),
(qubits[0],),
)
)

new_circ.data = new_data

if program_id == "sampler":
return SamplerPub.coerce((new_circ, val_data), pub.shots)
else:
return EstimatorPub.coerce((new_circ, pub.observables, val_data), pub.precision)
90 changes: 83 additions & 7 deletions test/unit/transpiler/passes/basis/test_fold_rzz_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,29 @@
"""Test folding Rzz angle into calibrated range."""

from math import pi
from ddt import ddt, named_data
from itertools import chain
import unittest
import numpy as np
from ddt import ddt, named_data, data, unpack

from qiskit.circuit import QuantumCircuit
from qiskit.circuit.parameter import Parameter
from qiskit.transpiler.passmanager import PassManager
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from qiskit.quantum_info import Operator
from qiskit.quantum_info import Operator, SparsePauliOp

from qiskit_ibm_runtime.transpiler.passes.basis import FoldRzzAngle
from qiskit_ibm_runtime.transpiler.passes.basis.fold_rzz_angle import (
FoldRzzAngle,
convert_to_rzz_valid_pub,
)
from qiskit_ibm_runtime.fake_provider import FakeFractionalBackend
from qiskit_ibm_runtime.utils.utils import is_valid_rzz_pub
from .....ibm_test_case import IBMTestCase


# pylint: disable=not-context-manager


@ddt
class TestFoldRzzAngle(IBMTestCase):
"""Test FoldRzzAngle pass"""
Expand Down Expand Up @@ -75,9 +84,9 @@ def test_controlflow(self):
"""Test non-ISA Rzz gates inside/outside a control flow branch."""
qc = QuantumCircuit(2, 1)
qc.rzz(-0.2, 0, 1)
with qc.if_test((0, 1)): # pylint: disable=not-context-manager
with qc.if_test((0, 1)):
qc.rzz(-0.1, 0, 1)
with qc.if_test((0, 1)): # pylint: disable=not-context-manager
with qc.if_test((0, 1)):
qc.rzz(-0.3, 0, 1)

pm = PassManager([FoldRzzAngle()])
Expand All @@ -87,11 +96,11 @@ def test_controlflow(self):
expected.x(0)
expected.rzz(0.2, 0, 1)
expected.x(0)
with expected.if_test((0, 1)): # pylint: disable=not-context-manager
with expected.if_test((0, 1)):
expected.x(0)
expected.rzz(0.1, 0, 1)
expected.x(0)
with expected.if_test((0, 1)): # pylint: disable=not-context-manager
with expected.if_test((0, 1)):
expected.x(0)
expected.rzz(0.3, 0, 1)
expected.x(0)
Expand All @@ -115,3 +124,70 @@ def test_fractional_plugin(self):
self.assertEqual(isa_circ.data[0].operation.name, "global_phase")
self.assertEqual(isa_circ.data[1].operation.name, "rzz")
self.assertTrue(np.isclose(isa_circ.data[1].operation.params[0], 7 - 2 * pi))

@data(
[0.2, 0.1, 0.4, 0.3, 2], # no modification in circuit
[0.2, 0.1, 0.3, 0.4, 3], # rzz_2_rx with values 0 and pi
[0.1, 0.2, 0.3, 0.4, 2], # x
[0.2, 0.1, 0.3, 2, 5], # rzz_1_rx, rzz_1_rz, rzz_2_rz with values 0 and pi
[0.3, 2, 0.3, 2, 2], # circuit changes but no new parameters
)
@unpack
def test_rzz_pub_conversion(self, p1_set1, p2_set1, p1_set2, p2_set2, expected_num_params):
"""Test the function `convert_to_rzz_valid_circ_and_vals`"""
p1 = Parameter("p1")
p2 = Parameter("p2")

circ = QuantumCircuit(3)
circ.rzz(p1 + p2, 0, 1)
circ.rzz(0.3, 0, 1)
circ.x(0)
circ.rzz(p1 - p2, 2, 1)

param_vals = [(p1_set1, p2_set1), (p1_set2, p2_set2)]
isa_pub = convert_to_rzz_valid_pub("sampler", (circ, param_vals))

isa_param_vals = isa_pub.parameter_values.ravel().as_array()
num_isa_params = len(isa_param_vals[0])
self.assertEqual(num_isa_params, expected_num_params)

self.assertEqual(is_valid_rzz_pub(isa_pub), "")
for param_set_1, param_set_2 in zip(param_vals, isa_param_vals):
self.assertTrue(
Operator.from_circuit(circ.assign_parameters(param_set_1)).equiv(
Operator.from_circuit(isa_pub.circuit.assign_parameters(param_set_2))
)
)

@unittest.skip("")
def test_rzz_pub_conversion_dynamic(self):
"""Test the function `convert_to_rzz_valid_circ_and_vals` for dynamic circuits"""
p = Parameter("p")
observable = SparsePauliOp("ZZZ")

circ = QuantumCircuit(3, 1)
with circ.if_test((0, 1)):
circ.rzz(p, 1, 2)
circ.rzz(p, 1, 2)
circ.rzz(p, 0, 1)
with circ.if_test((0, 1)):
circ.rzz(p, 1, 0)
circ.rzz(p, 1, 0)
circ.rzz(p, 0, 1)

isa_pub = convert_to_rzz_valid_pub("estimator", (circ, observable, [1, -1]))
self.assertEqual(is_valid_rzz_pub(isa_pub), "")
self.assertEqual([observable], isa_pub.observables)

# TODO: test qubit indices
isa_pub_param_names = np.array(list(chain.from_iterable(isa_pub.parameter_values.data)))
self.assertEqual(len(isa_pub_param_names), 6)
for param_name in [
"rzz_block1_rx1",
"rzz_block1_rx2",
"rzz_rx1",
"rzz_block2_rx1",
"rzz_block2_rx2",
"rzz_rx2",
]:
self.assertIn(param_name, isa_pub_param_names)