Skip to content

Commit bc06c61

Browse files
committed
refactoring: make CircuitTransformer into an abstract interface (quantumlib#142)
* refactor out an abstract interface for circuit transformers and allow constructing QuantumBoard with an optional custom transformer. * fix pytket cirq extension import after new version of pytket
1 parent 62abb63 commit bc06c61

File tree

4 files changed

+46
-31
lines changed

4 files changed

+46
-31
lines changed

Diff for: recirq/qaoa/placement.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import networkx as nx
77
import numpy as np
88
import pytket
9-
import pytket.cirq
9+
import pytket.extensions.cirq
1010
from pytket.circuit import Node, Qubit
1111
from pytket.passes import SequencePass, RoutingPass, PlacementPass
1212
from pytket.predicates import CompilationUnit, ConnectivityPredicate
@@ -87,7 +87,7 @@ def place_on_device(circuit: cirq.Circuit,
8787
initial_map: Initial placement of qubits
8888
final_map: The final placement of qubits after action of the circuit
8989
"""
90-
tk_circuit = pytket.cirq.cirq_to_tk(circuit)
90+
tk_circuit = pytket.extensions.cirq.cirq_to_tk(circuit)
9191
tk_device = _device_to_tket_device(device)
9292

9393
unit = CompilationUnit(tk_circuit, [ConnectivityPredicate(tk_device)])
@@ -103,7 +103,7 @@ def place_on_device(circuit: cirq.Circuit,
103103
for n1, n2 in unit.initial_map.items()}
104104
final_map = {tk_to_cirq_qubit(n1): tk_to_cirq_qubit(n2)
105105
for n1, n2 in unit.final_map.items()}
106-
routed_circuit = pytket.cirq.tk_to_cirq(unit.circuit)
106+
routed_circuit = pytket.extensions.cirq.tk_to_cirq(unit.circuit)
107107

108108
return routed_circuit, initial_map, final_map
109109

Diff for: recirq/quantum_chess/circuit_transformer.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,31 @@
2222

2323

2424
class CircuitTransformer:
25+
"""Abstract interface for circuit transformations.
26+
27+
For example: NamedQubit -> GridQubit transformations.
28+
"""
29+
def __init__(self):
30+
pass
31+
32+
def transform(self, circuit: cirq.Circuit) -> cirq.Circuit:
33+
"""Applies the transformation to the circuit."""
34+
return None
35+
36+
37+
class ConnectivityHeuristicCircuitTransformer(CircuitTransformer):
2538
"""Optimizer that will transform a circuit using NamedQubits
2639
and transform it to use GridQubits. This will use a breadth-first
2740
search to find a suitable mapping into the specified grid.
2841
2942
It will then transform all operations to use the new qubits.
3043
"""
31-
3244
def __init__(self, device: cirq.Device):
45+
super().__init__()
3346
self.device = device
3447
self.mapping = None
3548
self.starting_qubit = self.find_start_qubit(device.qubits)
3649
self.qubit_list = device.qubits
37-
super().__init__()
3850

3951
def qubits_within(self, depth: int, qubit: cirq.GridQubit,
4052
qubit_list: Iterable[cirq.GridQubit]) -> int:
@@ -54,7 +66,8 @@ def qubits_within(self, depth: int, qubit: cirq.GridQubit,
5466
c += self.qubits_within(depth - 1, qubit + diff, qubit_list)
5567
return c
5668

57-
def find_start_qubit(self, qubit_list: List[cirq.Qid],
69+
def find_start_qubit(self,
70+
qubit_list: List[cirq.Qid],
5871
depth=3) -> Optional[cirq.GridQubit]:
5972
"""Finds a reasonable starting qubit to start the mapping.
6073
@@ -278,7 +291,7 @@ def qubit_mapping(self,
278291
self.mapping = mapping
279292
return mapping
280293

281-
def optimize_circuit(self, circuit: cirq.Circuit) -> cirq.Circuit:
294+
def transform(self, circuit: cirq.Circuit) -> cirq.Circuit:
282295
""" Creates a new qubit mapping for a circuit and transforms it.
283296
284297
This uses `qubit_mapping` to create a mapping from the qubits
@@ -299,10 +312,9 @@ class SycamoreDecomposer(cirq.PointOptimizer):
299312
Currently supported are controlled ISWAPs with a single control
300313
and control-X gates with multiple controls (TOFFOLI gates).:w
301314
"""
302-
303-
def optimization_at(self, circuit: cirq.Circuit, index: int,
304-
op: cirq.Operation
305-
) -> Optional[cirq.PointOptimizationSummary]:
315+
def optimization_at(
316+
self, circuit: cirq.Circuit, index: int,
317+
op: cirq.Operation) -> Optional[cirq.PointOptimizationSummary]:
306318
if len(op.qubits) > 3:
307319
raise ValueError(f'Four qubit ops not yet supported: {op}')
308320
new_ops = None

Diff for: recirq/quantum_chess/circuit_transformer_test.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright 2020 Google
2-
#
2+
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
55
# You may obtain a copy of the License at
@@ -33,59 +33,59 @@
3333
@pytest.mark.parametrize('device',
3434
(cirq.google.Sycamore23, cirq.google.Sycamore))
3535
def test_single_qubit_ops(device):
36-
transformer = ct.CircuitTransformer(device)
36+
transformer = ct.ConnectivityHeuristicCircuitTransformer(device)
3737
c = cirq.Circuit(cirq.X(a1), cirq.X(a2), cirq.X(a3))
3838
transformer.qubit_mapping(c)
39-
c = transformer.optimize_circuit(c)
39+
c = transformer.transform(c)
4040
device.validate_circuit(c)
4141

4242

4343
@pytest.mark.parametrize('device',
4444
(cirq.google.Sycamore23, cirq.google.Sycamore))
4545
def test_single_qubit_with_two_qubits(device):
46-
transformer = ct.CircuitTransformer(device)
46+
transformer = ct.ConnectivityHeuristicCircuitTransformer(device)
4747
c = cirq.Circuit(cirq.X(a1), cirq.X(a2), cirq.X(a3),
4848
cirq.ISWAP(a3, a4) ** 0.5)
4949
transformer.qubit_mapping(c)
50-
device.validate_circuit(transformer.optimize_circuit(c))
50+
device.validate_circuit(transformer.transform(c))
5151

5252

5353
@pytest.mark.parametrize('device',
5454
(cirq.google.Sycamore23, cirq.google.Sycamore))
5555
def test_three_split_moves(device):
56-
transformer = ct.CircuitTransformer(device)
56+
transformer = ct.ConnectivityHeuristicCircuitTransformer(device)
5757
c = cirq.Circuit(qm.split_move(a1, a2, b1), qm.split_move(a2, a3, b3),
5858
qm.split_move(b1, c1, c2))
5959
transformer.qubit_mapping(c)
60-
device.validate_circuit(transformer.optimize_circuit(c))
60+
device.validate_circuit(transformer.transform(c))
6161

6262

6363
@pytest.mark.parametrize('device',
6464
(cirq.google.Sycamore23, cirq.google.Sycamore))
6565
def test_disconnected(device):
66-
transformer = ct.CircuitTransformer(device)
66+
transformer = ct.ConnectivityHeuristicCircuitTransformer(device)
6767
c = cirq.Circuit(qm.split_move(a1, a2, a3), qm.split_move(a3, a4, d1),
6868
qm.split_move(b1, b2, b3), qm.split_move(c1, c2, c3))
6969
transformer.qubit_mapping(c)
70-
device.validate_circuit(transformer.optimize_circuit(c))
70+
device.validate_circuit(transformer.transform(c))
7171

7272

7373
@pytest.mark.parametrize('device',
7474
(cirq.google.Sycamore23, cirq.google.Sycamore))
7575
def test_move_around_square(device):
76-
transformer = ct.CircuitTransformer(device)
76+
transformer = ct.ConnectivityHeuristicCircuitTransformer(device)
7777
c = cirq.Circuit(qm.normal_move(a1, a2), qm.normal_move(a2, b2),
7878
qm.normal_move(b2, b1), qm.normal_move(b1, a1))
7979
transformer.qubit_mapping(c)
80-
device.validate_circuit(transformer.optimize_circuit(c))
80+
device.validate_circuit(transformer.transform(c))
8181

8282

8383
@pytest.mark.parametrize('device',
8484
(cirq.google.Sycamore23, cirq.google.Sycamore))
8585
def test_split_then_merge(device):
86-
transformer = ct.CircuitTransformer(device)
86+
transformer = ct.ConnectivityHeuristicCircuitTransformer(device)
8787
c = cirq.Circuit(qm.split_move(a1, a2, b1), qm.split_move(a2, a3, b3),
8888
qm.split_move(b1, c1, c2), qm.normal_move(c1, d1),
8989
qm.normal_move(a3, a4), qm.merge_move(a4, d1, a1))
9090
transformer.qubit_mapping(c)
91-
device.validate_circuit(transformer.optimize_circuit(c))
91+
device.validate_circuit(transformer.transform(c))

Diff for: recirq/quantum_chess/quantum_board.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
square_to_bit,
2525
xy_to_bit,
2626
)
27-
import recirq.quantum_chess.circuit_transformer as circuit_transformer
27+
import recirq.quantum_chess.circuit_transformer as ct
2828
import recirq.quantum_chess.enums as enums
2929
import recirq.quantum_chess.move as move
3030
import recirq.quantum_chess.quantum_moves as qm
@@ -60,6 +60,8 @@ class CirqBoard:
6060
an error or post-selects the result away.
6161
noise_mitigation: Threshold of samples to overcome in order
6262
to be considered not noise.
63+
transformer: The CircuitTransformer to use to convert the board's
64+
NamedQubit circuit into a GridQubit circuit.
6365
"""
6466

6567
def __init__(self,
@@ -68,11 +70,14 @@ def __init__(self,
6870
device: Optional[cirq.Device] = None,
6971
error_mitigation: Optional[
7072
enums.ErrorMitigation] = enums.ErrorMitigation.Nothing,
71-
noise_mitigation: Optional[float] = 0.0):
73+
noise_mitigation: Optional[float] = 0.0,
74+
transformer: Optional[ct.CircuitTransformer] = None):
7275
self.device = device
7376
self.sampler = sampler
7477
if device is not None:
75-
self.transformer = circuit_transformer.CircuitTransformer(device)
78+
self.transformer = (
79+
transformer
80+
or ct.ConnectivityHeuristicCircuitTransformer(device))
7681
self.with_state(init_basis_state)
7782
self.error_mitigation = error_mitigation
7883
self.noise_mitigation = noise_mitigation
@@ -179,11 +184,9 @@ def sample_with_ancilla(self, num_samples: int
179184
# Translate circuit to grid qubits and sqrtISWAP gates
180185
if self.device is not None:
181186
# Decompose 3-qubit operations
182-
circuit_transformer.SycamoreDecomposer().optimize_circuit(
183-
measure_circuit)
187+
ct.SycamoreDecomposer().optimize_circuit(measure_circuit)
184188
# Create NamedQubit to GridQubit mapping and transform
185-
measure_circuit = self.transformer.optimize_circuit(
186-
measure_circuit)
189+
measure_circuit = self.transformer.transform(measure_circuit)
187190

188191
# For debug, ensure that the circuit correctly validates
189192
self.device.validate_circuit(measure_circuit)

0 commit comments

Comments
 (0)