Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from guppylang_internals.diagnostic import Error
from guppylang_internals.engine import DEF_STORE, ENGINE
from guppylang_internals.error import GuppyError, InternalGuppyError
from guppylang_internals.metadata.debug_info import StringTable
from guppylang_internals.std._internal.compiler.tket_exts import GUPPY_EXTENSION
from guppylang_internals.tys.arg import ConstArg, TypeArg
from guppylang_internals.tys.builtin import nat_type
Expand Down Expand Up @@ -151,16 +152,22 @@ class CompilerContext(ToHugrContext):

checked_globals: Globals

metadata_file_table: StringTable

def __init__(
self,
module: DefinitionBuilder[ops.Module],
file_table: StringTable | None = None,
) -> None:
self.module = module
self.worklist = {}
self.compiled = {}
self.global_funcs = {}
self.checked_globals = Globals(None)
self.current_mono_args = None
self.metadata_file_table = (
file_table if file_table is not None else StringTable([])
)

@contextmanager
def set_monomorphized_args(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
import hugr.std.int
import hugr.std.logic
import hugr.std.prelude
from hugr import Wire, ops
from hugr import Node, Wire, ops
from hugr import tys as ht
from hugr import val as hv
from hugr.build import function as hf
from hugr.build.cond_loop import Conditional
from hugr.build.dfg import DP, DfBase

from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
from guppylang_internals.ast_util import AstNode, AstVisitor, get_file, get_type
from guppylang_internals.cfg.builder import tmp_vars
from guppylang_internals.checker.core import Variable, contains_subscript
from guppylang_internals.checker.errors.generic import UnsupportedError
Expand All @@ -28,6 +28,7 @@
GlobalConstId,
)
from guppylang_internals.compiler.hugr_extension import PartialOp
from guppylang_internals.debug_mode import debug_mode_enabled
from guppylang_internals.definition.custom import CustomFunctionDef
from guppylang_internals.definition.value import (
CallableDef,
Expand All @@ -37,6 +38,7 @@
)
from guppylang_internals.engine import ENGINE
from guppylang_internals.error import GuppyError, InternalGuppyError
from guppylang_internals.metadata.debug_info import HugrDebugInfo, make_location_record
from guppylang_internals.nodes import (
AbortExpr,
AbortKind,
Expand Down Expand Up @@ -129,6 +131,12 @@ def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[Wire]:
"""
return [self.compile(e, dfg) for e in expr_to_row(expr)]

def add_op(
self, op: ops.DataflowOp, /, *args: Wire, ast_node: AstNode | None = None
) -> Node:
"""Adds an op to the builder, with optional debug info."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the ast_node the debug info? seems more correct to say "with optional AST node related to the op"

return add_op(self.builder, op, *args, ast_node=ast_node)

@property
def builder(self) -> DfBase[ops.DfParentOp]:
"""The current Hugr dataflow graph builder."""
Expand Down Expand Up @@ -214,7 +222,7 @@ def _if_else(
cond_wire = self.visit(cond)
cond_ty = self.builder.hugr.port_type(cond_wire.out_port())
if cond_ty == OpaqueBool:
cond_wire = self.builder.add_op(read_bool(), cond_wire)
cond_wire = self.add_op(read_bool(), cond_wire)
conditional = self.builder.add_conditional(
cond_wire, *(self.visit(inp) for inp in inputs)
)
Expand Down Expand Up @@ -286,8 +294,8 @@ def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
[arg], ht.FunctionType([], [ht.USize()])
)
usize = self.builder.add_op(load_nat)
return self.builder.add_op(convert_ifromusize(), usize)
usize = self.add_op(load_nat, ast_node=node)
return self.add_op(convert_ifromusize(), usize, ast_node=node)
case ty:
# Look up monomorphization
match self.ctx.current_mono_args[node.param.idx]:
Expand Down Expand Up @@ -316,12 +324,12 @@ def visit_List(self, node: ast.List) -> Wire:
def _unpack_tuple(self, wire: Wire, types: Sequence[Type]) -> Sequence[Wire]:
"""Add a tuple unpack operation to the graph"""
types = [t.to_hugr(self.ctx) for t in types]
return list(self.builder.add_op(ops.UnpackTuple(types), wire))
return list(self.add_op(ops.UnpackTuple(types), wire))

def _pack_tuple(self, wires: Sequence[Wire], types: Sequence[Type]) -> Wire:
"""Add a tuple pack operation to the graph"""
types = [t.to_hugr(self.ctx) for t in types]
return self.builder.add_op(ops.MakeTuple(types), *wires)
return self.add_op(ops.MakeTuple(types), *wires)

def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire:
"""Groups function return values into a tuple"""
Expand Down Expand Up @@ -363,8 +371,8 @@ def visit_LocalCall(self, node: LocalCall) -> Wire:
num_returns = len(type_to_row(func_ty.output))

args = self._compile_call_args(node.args, func_ty)
call = self.builder.add_op(
ops.CallIndirect(func_ty.to_hugr(self.ctx)), func, *args
call = self.add_op(
ops.CallIndirect(func_ty.to_hugr(self.ctx)), func, *args, ast_node=node
)
regular_returns = list(call[:num_returns])
inout_returns = call[num_returns:]
Expand Down Expand Up @@ -420,7 +428,7 @@ def _compile_tensor_with_leftovers(
num_returns = len(type_to_row(func_ty.output))
consumed_args, other_args = args[0:input_len], args[input_len:]
consumed_wires = self._compile_call_args(consumed_args, func_ty)
call = self.builder.add_op(
call = self.add_op(
ops.CallIndirect(func_ty.to_hugr(self.ctx)), func, *consumed_wires
)
regular_returns: list[Wire] = list(call[:num_returns])
Expand Down Expand Up @@ -472,8 +480,11 @@ def visit_PartialApply(self, node: PartialApply) -> Wire:
func_ty.to_hugr(self.ctx),
[get_type(arg).to_hugr(self.ctx) for arg in node.args],
)
return self.builder.add_op(
op, self.visit(node.func), *(self.visit(arg) for arg in node.args)
return self.add_op(
op,
self.visit(node.func),
*(self.visit(arg) for arg in node.args),
ast_node=node,
)

def visit_TypeApply(self, node: TypeApply) -> Wire:
Expand Down Expand Up @@ -503,7 +514,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> Wire:
# since it is not implemented via a dunder method
if isinstance(node.op, ast.Not):
arg = self.visit(node.operand)
return self.builder.add_op(not_op(), arg)
return self.add_op(not_op(), arg, ast_node=node)

raise InternalGuppyError("Node should have been removed during type checking.")

Expand Down Expand Up @@ -561,9 +572,9 @@ def _visit_result_tag(self, tag: Const, loc: ast.expr) -> str:

def visit_AbortExpr(self, node: AbortExpr) -> Wire:
signal = self.visit(node.signal)
signal_usize = self.builder.add_op(convert_itousize(), signal)
signal_usize = self.add_op(convert_itousize(), signal, ast_node=node)
msg = self.visit(node.msg)
err = self.builder.add_op(make_error(), signal_usize, msg)
err = self.add_op(make_error(), signal_usize, msg, ast_node=node)
in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
args = [self.visit(e) for e in node.values]
Expand All @@ -572,7 +583,7 @@ def visit_AbortExpr(self, node: AbortExpr) -> Wire:
h_node = build_panic(self.builder, in_tys, out_tys, err, *args)
case AbortKind.ExitShot:
op = panic(in_tys, out_tys, AbortKind.ExitShot)
h_node = self.builder.add_op(op, err, *args)
h_node = self.add_op(op, err, *args, ast_node=node)
return self._pack_returns(list(h_node.outputs()), get_type(node))

def visit_BarrierExpr(self, node: BarrierExpr) -> Wire:
Expand All @@ -582,7 +593,7 @@ def visit_BarrierExpr(self, node: BarrierExpr) -> Wire:
ht.FunctionType.endo(hugr_tys),
)

barrier_n = self.builder.add_op(op, *(self.visit(e) for e in node.args))
barrier_n = self.add_op(op, *(self.visit(e) for e in node.args), ast_node=node)

self._update_inout_ports(node.args, iter(barrier_n), node.func_ty)
return self._pack_returns([], NoneType())
Expand All @@ -605,27 +616,37 @@ def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
if not node.array_len:
# If the input is a sequence of qubits, we pack them into an array.
qubits_in = [self.visit(e) for e in node.args[1:]]
qubit_arr_in = self.builder.add_op(
array_new(ht.Qubit, len(node.args) - 1), *qubits_in
qubit_arr_in = self.add_op(
array_new(ht.Qubit, len(node.args) - 1), *qubits_in, ast_node=node
)
# Turn into standard array from borrow array.
qubit_arr_in = self.builder.add_op(
array_to_std_array(ht.Qubit, num_qubits_arg), qubit_arr_in
qubit_arr_in = self.add_op(
array_to_std_array(ht.Qubit, num_qubits_arg),
qubit_arr_in,
ast_node=node,
)

qubit_arr_out = self.builder.add_op(op, qubit_arr_in)
qubit_arr_out = self.add_op(op, qubit_arr_in, ast_node=node)

qubit_arr_out = self.builder.add_op(
std_array_to_array(ht.Qubit, num_qubits_arg), qubit_arr_out
qubit_arr_out = self.add_op(
std_array_to_array(ht.Qubit, num_qubits_arg),
qubit_arr_out,
ast_node=node,
)
qubits_out = unpack_array(self.builder, qubit_arr_out)
qubits_out = unpack_array(self.builder, qubit_arr_out, ast_node=node)
else:
# If the input is an array of qubits, we need to convert to a standard
# array.
qubits_in = [self.visit(node.args[1])]
qubits_out = [
apply_array_op_with_conversions(
self.ctx, self.builder, op, ht.Qubit, num_qubits_arg, qubits_in[0]
self.ctx,
self.builder,
op,
ht.Qubit,
num_qubits_arg,
qubits_in[0],
ast_node=node,
)
]

Expand Down Expand Up @@ -655,16 +676,18 @@ def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> Wire:
count_var = Variable(next(tmp_vars), int_type(), node)
hugr_elt_ty = node.elt_ty.to_hugr(self.ctx)
# Initialise empty array.
self.dfg[array_var] = self.builder.add_op(
barray_new_all_borrowed(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx))
self.dfg[array_var] = self.add_op(
barray_new_all_borrowed(
hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)
),
)
self.dfg[count_var] = self.builder.load(
hugr.std.int.IntVal(0, width=NumericType.INT_WIDTH)
)
with self._build_generators([node.generator], [array_var, count_var]):
elt = self.visit(node.elt)
array, count = self.dfg[array_var], self.dfg[count_var]
idx = self.builder.add_op(convert_itousize(), count)
idx = self.add_op(convert_itousize(), count)
self.dfg[array_var] = self.builder.add_op(
barray_return(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)),
array,
Expand Down Expand Up @@ -748,6 +771,23 @@ def visit_Compare(self, node: ast.Compare) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")


P = TypeVar("P", bound=ops.DfParentOp)


def add_op(
builder: DfBase[P],
op: ops.DataflowOp,
/,
*args: Wire,
ast_node: AstNode | None = None,
) -> Node:
"""Adds an op to the builder, with optional debug info."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above about "debug info"

op_node = builder.add_op(op, *args)
if debug_mode_enabled() and ast_node is not None and get_file(ast_node) is not None:
op_node.metadata[HugrDebugInfo] = make_location_record(ast_node)
return op_node


def expr_to_row(expr: ast.expr) -> list[ast.expr]:
"""Turns an expression into a row expressions by unpacking top-level tuples."""
return expr.elts if isinstance(expr, ast.Tuple) else [expr]
Expand All @@ -758,27 +798,36 @@ def pack_returns(
return_ty: Type,
builder: DfBase[ops.DfParentOp],
ctx: CompilerContext,
ast_node: AstNode | None = None,
) -> Wire:
"""Groups function return values into a tuple"""
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
types = type_to_row(return_ty)
assert len(returns) == len(types)
hugr_tys = [t.to_hugr(ctx) for t in types]
return builder.add_op(ops.MakeTuple(hugr_tys), *returns)
return add_op(builder, ops.MakeTuple(hugr_tys), *returns, ast_node=ast_node)
assert len(returns) == 1, (
f"Expected a single return value. Got {returns}. return type {return_ty}"
)
return returns[0]


def unpack_wire(
wire: Wire, return_ty: Type, builder: DfBase[ops.DfParentOp], ctx: CompilerContext
wire: Wire,
return_ty: Type,
builder: DfBase[ops.DfParentOp],
ctx: CompilerContext,
ast_node: AstNode | None = None,
) -> list[Wire]:
"""The inverse of `pack_returns`"""
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
types = type_to_row(return_ty)
hugr_tys = [t.to_hugr(ctx) for t in types]
return list(builder.add_op(ops.UnpackTuple(hugr_tys), wire).outputs())
return list(
add_op(
builder, ops.UnpackTuple(hugr_tys), wire, ast_node=ast_node
).outputs()
)
return [wire]


Expand Down Expand Up @@ -885,6 +934,7 @@ def apply_array_op_with_conversions(
size_arg: ht.TypeArg,
input_array: Wire,
convert_bool: bool = False,
ast_node: AstNode | None = None,
) -> Wire:
"""Applies common transformations to a Guppy array input before it can be passed to
a Hugr op operating on a standard Hugr array, and then reverses them again on the
Expand All @@ -898,20 +948,28 @@ def apply_array_op_with_conversions(
array_read = array_read_bool(ctx)
array_read = builder.load_function(array_read)
map_op = array_map(OpaqueBool, size_arg, ht.Bool)
input_array = builder.add_op(map_op, input_array, array_read)
input_array = add_op(
builder, map_op, input_array, array_read, ast_node=ast_node
)
elem_ty = ht.Bool

input_array = builder.add_op(array_to_std_array(elem_ty, size_arg), input_array)
input_array = add_op(
builder, array_to_std_array(elem_ty, size_arg), input_array, ast_node=ast_node
)

result_array = builder.add_op(op, input_array)
result_array = add_op(builder, op, input_array, ast_node=ast_node)

result_array = builder.add_op(std_array_to_array(elem_ty, size_arg), result_array)
result_array = add_op(
builder, std_array_to_array(elem_ty, size_arg), result_array, ast_node=ast_node
)

if convert_bool:
array_make_opaque = array_make_opaque_bool(ctx)
array_make_opaque = builder.load_function(array_make_opaque)
map_op = array_map(ht.Bool, size_arg, OpaqueBool)
result_array = builder.add_op(map_op, result_array, array_make_opaque)
result_array = add_op(
builder, map_op, result_array, array_make_opaque, ast_node=ast_node
)
elem_ty = OpaqueBool

return result_array
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from guppylang_internals.compiler.cfg_compiler import compile_cfg
from guppylang_internals.compiler.core import CompilerContext, DFContainer
from guppylang_internals.compiler.expr_compiler import ExprCompiler
from guppylang_internals.definition.metadata import add_metadata
from guppylang_internals.metadata.common import add_metadata
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
from guppylang_internals.std._internal.compiler.array import (
array_new,
Expand Down
18 changes: 18 additions & 0 deletions guppylang-internals/src/guppylang_internals/debug_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Global state for determining whether to attach debug information to Hugr nodes
during compilation."""

DEBUG_MODE_ENABLED = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be private?

Suggested change
DEBUG_MODE_ENABLED = False
_DEBUG_MODE_ENABLED = False



def turn_on_debug_mode() -> None:
global DEBUG_MODE_ENABLED
DEBUG_MODE_ENABLED = True


def turn_off_debug_mode() -> None:
global DEBUG_MODE_ENABLED
DEBUG_MODE_ENABLED = False


def debug_mode_enabled() -> bool:
return DEBUG_MODE_ENABLED
Loading
Loading