Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QasmModule Circuit Drawer #122

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7fac964
basic one qubit drawer
arulandu Jan 11, 2025
40688bf
controlled gates + swap gate
arulandu Jan 11, 2025
197f309
.draw() on module
arulandu Jan 11, 2025
cd6d5eb
measurement
arulandu Jan 11, 2025
c79b46d
moment based drawing
arulandu Jan 11, 2025
cbd9293
pad label
arulandu Jan 11, 2025
9baed7b
multi control
arulandu Jan 14, 2025
0ff03af
custom test
arulandu Jan 14, 2025
56ae343
fix double draw
arulandu Jan 16, 2025
dbc5f3e
angle font
arulandu Jan 17, 2025
d9c27f8
idle wires option
arulandu Jan 17, 2025
0076925
add barrier support
arulandu Jan 17, 2025
77d6897
group cregs + measure style
arulandu Jan 17, 2025
5d61b1a
sections + idle wires order
arulandu Jan 17, 2025
ada2513
overflow wrap + idle wires order fix
arulandu Jan 17, 2025
4a7848c
linting
arulandu Jan 17, 2025
a66a3fa
merge main
arulandu Jan 17, 2025
5575332
fix kwargs
arulandu Jan 17, 2025
f6d4d77
phase placeholder
arulandu Jan 17, 2025
11eeb84
global phase
arulandu Jan 20, 2025
8170919
Merge branch 'main' into draw
TheGupta2012 Jan 28, 2025
8d8bfec
reset + labels + linting
arulandu Feb 14, 2025
871750a
Merge branch 'main' into draw
ryanhill1 Feb 14, 2025
6750d22
build, static type checking, user interface
ryanhill1 Feb 14, 2025
f8e06c8
Merge branch 'qBraid:main' into draw
arulandu Feb 18, 2025
c2cb5de
standalone measurement
arulandu Feb 18, 2025
0528696
add tests
ryanhill1 Feb 20, 2025
02ca6cb
add mpl to test extra
ryanhill1 Feb 20, 2025
705cc48
try pytest-mpl
ryanhill1 Feb 20, 2025
d5ae252
clean up tests
ryanhill1 Feb 20, 2025
d20d8ea
Merge branch 'main' into draw
ryanhill1 Feb 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ classifiers = [
"Operating System :: Unix",
"Operating System :: MacOS",
]

dependencies = ["numpy", "openqasm3[parser]>=1.0.0,<2.0.0"]

[project.urls]
Expand All @@ -43,6 +44,7 @@ cli = ["typer>=0.12.1", "rich>=10.11.0", "typing-extensions"]
test = ["pytest", "pytest-cov"]
lint = ["black", "isort", "pylint", "mypy", "qbraid-cli>=0.8.5"]
docs = ["sphinx>=7.3.7,<8.2.0", "sphinx-autodoc-typehints>=1.24,<2.6", "sphinx-rtd-theme>=2.0.0,<4.0.0", "docutils<0.22", "sphinx-copybutton"]
visualization = ["matplotlib"]

[tool.setuptools_scm]
write_to = "src/pyqasm/_version.py"
Expand Down
15 changes: 15 additions & 0 deletions src/pyqasm/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,21 @@ def map_qasm_inv_op_to_callable(op_name: str):
raise ValidationError(f"Unsupported / undeclared QASM operation: {op_name}")


REV_CTRL_GATE_MAP = {
"cx": "x",
"cy": "y",
"cz": "z",
"crx": "rx",
"cry": "ry",
"crz": "rz",
"cp": "p",
"ch": "h",
"cu": "u",
"cswap": "swap",
"ccx": "cx",
}


# pylint: disable=inconsistent-return-statements
def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value):
"""Cast the variable type to the type to match, if possible.
Expand Down
4 changes: 4 additions & 0 deletions src/pyqasm/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,7 @@ def accept(self, visitor):
Args:
visitor (QasmVisitor): The visitor to accept
"""

@abstractmethod
def draw(self):
"""Draw the module"""
5 changes: 5 additions & 0 deletions src/pyqasm/modules/qasm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pyqasm.exceptions import ValidationError
from pyqasm.modules.base import QasmModule
from pyqasm.modules.qasm3 import Qasm3Module
from pyqasm.printer import draw


class Qasm2Module(QasmModule):
Expand Down Expand Up @@ -105,3 +106,7 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self.unrolled_ast.statements = final_stmt_list

def draw(self):
"""Draw the module"""
return draw(self.to_qasm3())
5 changes: 5 additions & 0 deletions src/pyqasm/modules/qasm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from openqasm3.printer import dumps

from pyqasm.modules.base import QasmModule
from pyqasm.printer import draw


class Qasm3Module(QasmModule):
Expand Down Expand Up @@ -48,3 +49,7 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self._unrolled_ast.statements = final_stmt_list

def draw(self):
"""Draw the module"""
return draw(self)
306 changes: 306 additions & 0 deletions src/pyqasm/printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
# Copyright (C) 2024 qBraid
#
# This file is part of pyqasm
#
# Pyqasm is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for pyqasm, as per Section 15 of the GPL v3.

"""
Module with analysis functions for QASM visitor

"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Union

import openqasm3.ast as ast

from pyqasm.expressions import Qasm3ExprEvaluator
from pyqasm.maps import (
FIVE_QUBIT_OP_MAP,
FOUR_QUBIT_OP_MAP,
ONE_QUBIT_OP_MAP,
ONE_QUBIT_ROTATION_MAP,
REV_CTRL_GATE_MAP,
THREE_QUBIT_OP_MAP,
TWO_QUBIT_OP_MAP,
)

try:
from matplotlib import pyplot as plt
mpl_installed = True
except ImportError as e:
mpl_installed = False

if TYPE_CHECKING:
from pyqasm.modules.base import Qasm3Module



DEFAULT_GATE_COLOR = "#d4b6e8"
HADAMARD_GATE_COLOR = "#f0a6a6"

GATE_BOX_WIDTH, GATE_BOX_HEIGHT = 0.6, 0.6
GATE_SPACING = 0.2
LINE_SPACING = 0.6
TEXT_MARGIN = 0.6
FRAME_PADDING = 0.2


def draw(module: Qasm3Module, output="mpl"):
if not mpl_installed:
raise ImportError("matplotlib needs to be installed prior to running pyqasm.draw(). You can install matplotlib with:\n'pip install pyqasm[visualization]'")

if output == "mpl":
return _draw_mpl(module)
else:
raise NotImplementedError(f"{output} drawing for Qasm3Module is unsupported")


def _draw_mpl(module: Qasm3Module) -> plt.Figure:
module.unroll()
module.remove_includes()
module.remove_barriers()

n_lines = module._num_qubits + module._num_clbits
statements = module._statements

# compute line numbers per qubit + max depth
line_nums = dict()
line_num = -1
max_depth = 0

for clbit_reg in module._classical_registers.keys():
size = module._classical_registers[clbit_reg]
line_num += size
for i in range(size):
line_nums[(clbit_reg, i)] = line_num
line_num -= 1
line_num += size

for qubit_reg in module._qubit_registers.keys():
size = module._qubit_registers[qubit_reg]
line_num += size
for i in range(size):
line_nums[(qubit_reg, i)] = line_num
depth = module._qubit_depths[(qubit_reg, i)]._total_ops()
max_depth = max(max_depth, depth)
line_num -= 1
line_num += size

# compute moments
depths = dict()
for k in line_nums.keys():
depths[k] = -1

moments = []
for statement in statements:
if "Declaration" in str(type(statement)):
continue
if isinstance(statement, ast.QuantumGate):
qubits = [_identifier_to_key(q) for q in statement.qubits]
depth = 1 + max([depths[q] for q in qubits])
for q in qubits:
depths[q] = depth
elif isinstance(statement, ast.QuantumMeasurementStatement):
qubit_key = _identifier_to_key(statement.measure.qubit)
target_key = _identifier_to_key(statement.target)
depth = 1 + max(depths[qubit_key], depths[target_key])
for k in [qubit_key, target_key]:
depths[k] = depth
elif isinstance(statement, ast.QuantumBarrier):
pass
elif isinstance(statement, ast.QuantumReset):
pass
else:
raise NotImplementedError(f"Unsupported statement: {statement}")

if depth >= len(moments):
moments.append([])
moments[depth].append(statement)

width = 0
for moment in moments:
width += _mpl_get_moment_width(moment)
width += TEXT_MARGIN

fig, ax = plt.subplots(
figsize=(width, n_lines * GATE_BOX_HEIGHT + LINE_SPACING * (n_lines - 1))
)
ax.set_ylim(
-GATE_BOX_HEIGHT / 2 - FRAME_PADDING / 2,
n_lines * GATE_BOX_HEIGHT
+ LINE_SPACING * (n_lines - 1)
- GATE_BOX_HEIGHT / 2
+ FRAME_PADDING / 2,
)
ax.set_xlim(-FRAME_PADDING / 2, width)
ax.axis("off")
# ax.set_aspect('equal')
# plt.tight_layout()

x = 0
for k in module._qubit_registers.keys():
for i in range(module._qubit_registers[k]):
line_num = line_nums[(k, i)]
_mpl_draw_qubit_label((k, i), line_num, ax, x)
for k in module._classical_registers.keys():
for i in range(module._classical_registers[k]):
line_num = line_nums[(k, i)]
_mpl_draw_clbit_label((k, i), line_num, ax, x)
x += TEXT_MARGIN
x0 = x
for moment in moments:
dx = _mpl_get_moment_width(moment)
_mpl_draw_lines(dx, line_nums, ax, x)
x += dx
x = x0
for moment in moments:
dx = _mpl_get_moment_width(moment)
for statement in moment:
_mpl_draw_statement(statement, line_nums, ax, x)
x += dx

return fig


def _identifier_to_key(identifier: ast.Identifier | ast.IndexedIdentifier) -> tuple[str, int]:
if isinstance(identifier, ast.Identifier):
return identifier.name, -1
else:
return (
identifier.name.name,
Qasm3ExprEvaluator.evaluate_expression(identifier.indices[0][0])[0],
)


def _mpl_line_to_y(line_num: int) -> float:
return line_num * (GATE_BOX_HEIGHT + LINE_SPACING)


def _mpl_draw_qubit_label(qubit: tuple[str, int], line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{qubit[0]}[{qubit[1]}]", ha="right", va="center")


def _mpl_draw_clbit_label(clbit: tuple[str, int], line_num: int, ax: plt.Axes, x: float):
ax.text(x, _mpl_line_to_y(line_num), f"{clbit[0]}[{clbit[1]}]", ha="right", va="center")


def _mpl_draw_lines(width, line_nums: dict[tuple[str, int], int], ax: plt.Axes, x: float):
for k in line_nums.keys():
y = _mpl_line_to_y(line_nums[k])
ax.hlines(
xmin=x - width / 2, xmax=x + width / 2, y=y, color="black", linestyle="-", zorder=-10
)


def _mpl_get_moment_width(moment: list[ast.QuantumStatement]) -> float:
return max([_mpl_get_statement_width(s) for s in moment])


def _mpl_get_statement_width(statement: ast.QuantumStatement) -> float:
return GATE_BOX_WIDTH + GATE_SPACING


def _mpl_draw_statement(
statement: ast.QuantumStatement, line_nums: dict[tuple[str, int], int], ax: plt.Axes, x: float
):
if isinstance(statement, ast.QuantumGate):
args = [Qasm3ExprEvaluator.evaluate_expression(arg)[0] for arg in statement.arguments]
lines = [line_nums[_identifier_to_key(q)] for q in statement.qubits]
_mpl_draw_gate(statement, args, lines, ax, x)
elif isinstance(statement, ast.QuantumMeasurementStatement):
qubit_key = _identifier_to_key(statement.measure.qubit)
target_key = _identifier_to_key(statement.target)
_mpl_draw_measurement(line_nums[qubit_key], line_nums[target_key], ax, x)
else:
raise NotImplementedError(f"Unsupported statement: {statement}")


def _mpl_draw_gate(
gate: ast.QuantumGate, args: list[Any], lines: list[int], ax: plt.Axes, x: float
):
name = gate.name.name
if name in REV_CTRL_GATE_MAP:
i = 0
while name in REV_CTRL_GATE_MAP:
name = REV_CTRL_GATE_MAP[name]
_draw_mpl_control(lines[i], lines[-1], ax, x)
i += 1
lines = lines[i:]
gate.name.name = name

if name in ONE_QUBIT_OP_MAP or name in ONE_QUBIT_ROTATION_MAP:
_draw_mpl_one_qubit_gate(gate, args, lines[0], ax, x)
elif name in TWO_QUBIT_OP_MAP:
if name == "swap":
_draw_mpl_swap(lines[0], lines[1], ax, x)
else:
raise NotImplementedError(f"Unsupported gate: {name}")
else:
raise NotImplementedError(f"Unsupported gate: {name}")


# TODO: switch to moment based system. go progressively, calculating required width for each moment, center the rest. this makes position calculations not to bad. if we overflow, start a new figure.


def _draw_mpl_one_qubit_gate(
gate: ast.QuantumGate, args: list[Any], line: int, ax: plt.Axes, x: float
):
color = DEFAULT_GATE_COLOR
if gate.name.name == "h":
color = HADAMARD_GATE_COLOR
text = gate.name.name.upper()
if len(args) > 0:
text += f"\n({', '.join([f'{a:.3f}' if isinstance(a, float) else str(a) for a in args])})"

y = _mpl_line_to_y(line)
rect = plt.Rectangle(
(x - GATE_BOX_WIDTH / 2, y - GATE_BOX_HEIGHT / 2),
GATE_BOX_WIDTH,
GATE_BOX_HEIGHT,
facecolor=color,
edgecolor="none",
)
ax.add_patch(rect)
ax.text(x, y, text, ha="center", va="center")


def _draw_mpl_control(ctrl_line: int, target_line: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(ctrl_line)
y2 = _mpl_line_to_y(target_line)
ax.vlines(x=x, ymin=min(y1, y2), ymax=max(y1, y2), color="black", linestyle="-", zorder=-1)
ax.plot(x, y1, "ko", markersize=8, markerfacecolor="black")


def _draw_mpl_swap(line1: int, line2: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(line1)
y2 = _mpl_line_to_y(line2)
ax.vlines(x=x, ymin=min(y1, y2), ymax=max(y1, y2), color="black", linestyle="-")
ax.plot(x, y1, "x", markersize=8, color="black")
ax.plot(x, y2, "x", markersize=8, color="black")


def _mpl_draw_measurement(qbit_line: int, cbit_line: int, ax: plt.Axes, x: float):
y1 = _mpl_line_to_y(qbit_line)
y2 = _mpl_line_to_y(cbit_line)

rect = plt.Rectangle(
(x - GATE_BOX_WIDTH / 2, y1 - GATE_BOX_HEIGHT / 2),
GATE_BOX_WIDTH,
GATE_BOX_HEIGHT,
facecolor="gray",
edgecolor="none",
)
ax.add_patch(rect)
ax.text(x, y1, "M", ha="center", va="center")
ax.vlines(
x=x - 0.025, ymin=min(y1, y2), ymax=max(y1, y2), color="gray", linestyle="-", zorder=-1
)
ax.vlines(
x=x + 0.025, ymin=min(y1, y2), ymax=max(y1, y2), color="gray", linestyle="-", zorder=-1
)
ax.plot(x, y2 + 0.1, "v", markersize=16, color="gray")
Loading