Skip to content
20 changes: 18 additions & 2 deletions guppylang-internals/src/guppylang_internals/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
Power,
)
from guppylang_internals.span import Span, to_span
from guppylang_internals.tys.ty import NoneType
from guppylang_internals.tys.ty import NoneType, UnitaryFlags

# In order to build expressions, need an endless stream of unique temporary variables
# to store intermediate results
Expand Down Expand Up @@ -78,14 +78,21 @@ class CFGBuilder(AstVisitor[BB | None]):
cfg: CFG
globals: Globals

def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG:
def build(
self,
nodes: list[ast.stmt],
returns_none: bool,
globals: Globals,
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
) -> CFG:
"""Builds a CFG from a list of ast nodes.

We also require the expected number of return ports for the whole CFG. This is
needed to translate return statements into assignments of dummy return
variables.
"""
self.cfg = CFG()
self.cfg.unitary_flags = unitary_flags
self.globals = globals

final_bb = self.visit_stmts(
Expand Down Expand Up @@ -273,6 +280,7 @@ def visit_FunctionDef(

func_ty = check_signature(node, self.globals)
returns_none = isinstance(func_ty.output, NoneType)
# No UnitaryFlags are assigned to nested functions
cfg = CFGBuilder().build(node.body, returns_none, self.globals)

new_node = NestedFunctionDef(
Expand Down Expand Up @@ -300,6 +308,14 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None:
modifier = self._handle_withitem(item)
new_node.push_modifier(modifier)

# FIXME: Currently, the unitary flags is not set correctly if there are nested
# `with` blocks. This is because the outer block's unitary flags are not
# propagated to the outer block. The following line should calculate the sum
# of the unitary flags of the outer block and modifiers applied in this
# `with` block.
unitary_flags = new_node.flags()
object.__setattr__(cfg, "unitary_flags", unitary_flags)

set_location_from(new_node, node)
bb.statements.append(new_node)
return bb
Expand Down
3 changes: 3 additions & 0 deletions guppylang-internals/src/guppylang_internals/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from guppylang_internals.cfg.bb import BB, BBStatement, VariableStats
from guppylang_internals.nodes import InoutReturnSentinel
from guppylang_internals.tys.ty import UnitaryFlags

T = TypeVar("T", bound=BB)

Expand All @@ -29,6 +30,7 @@ class BaseCFG(Generic[T]):

#: Set of variables defined in this CFG
assigned_somewhere: set[str]
unitary_flags: UnitaryFlags

def __init__(
self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
Expand All @@ -42,6 +44,7 @@ def __init__(
self.ass_before = {}
self.maybe_ass_before = {}
self.assigned_somewhere = set()
self.unitary_flags = UnitaryFlags.NoFlags

def ancestors(self, *bbs: T) -> Iterator[T]:
"""Returns an iterator over all ancestors of the given BBs in BFS order."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,17 @@ def check_cfg(
checked_cfg.maybe_ass_before = {
compiled[bb]: cfg.maybe_ass_before[bb] for bb in required_bbs
}
checked_cfg.unitary_flags = cfg.unitary_flags

# Finally, run the linearity check
from guppylang_internals.checker.linearity_checker import check_cfg_linearity

linearity_checked_cfg = check_cfg_linearity(checked_cfg, func_name, globals)

from guppylang_internals.checker.unitary_checker import check_cfg_unitary

check_cfg_unitary(linearity_checked_cfg, cfg.unitary_flags)

return linearity_checked_cfg


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
from guppylang_internals.checker.core import Context, Globals, Place, Variable
from guppylang_internals.checker.errors.generic import UnsupportedError
from guppylang_internals.checker.unitary_checker import check_invalid_under_dagger
from guppylang_internals.definition.common import DefId
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.diagnostic import Error, Help, Note
Expand All @@ -37,6 +38,7 @@
InputFlags,
NoneType,
Type,
UnitaryFlags,
unify,
)

Expand Down Expand Up @@ -136,7 +138,8 @@ def check_global_func_def(
returns_none = isinstance(ty.output, NoneType)
assert ty.input_names is not None

cfg = CFGBuilder().build(func_def.body, returns_none, globals)
check_invalid_under_dagger(func_def, ty.unitary_flags)
cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags)
inputs = [
Variable(x, inp.ty, loc, inp.flags, is_func_input=True)
for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True)
Expand All @@ -150,7 +153,9 @@ def check_global_func_def(


def check_nested_func_def(
func_def: NestedFunctionDef, bb: BB, ctx: Context
func_def: NestedFunctionDef,
bb: BB,
ctx: Context,
) -> CheckedNestedFunctionDef:
"""Type checks a local (nested) function definition."""
func_ty = check_signature(func_def, ctx.globals)
Expand Down Expand Up @@ -238,7 +243,10 @@ def check_nested_func_def(


def check_signature(
func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None
func_def: ast.FunctionDef,
globals: Globals,
def_id: DefId | None = None,
unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
) -> FunctionType:
"""Checks the signature of a function definition and returns the corresponding
Guppy type.
Expand Down Expand Up @@ -307,6 +315,7 @@ def check_signature(
output,
input_names,
sorted(param_var_mapping.values(), key=lambda v: v.idx),
unitary_flags=unitary_flags,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,7 @@ def live_places_row(
result_cfg.maybe_ass_before = {
checked[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
}
result_cfg.unitary_flags = cfg.unitary_flags
for bb in cfg.bbs:
checked[bb].predecessors = [checked[pred] for pred in bb.predecessors]
checked[bb].successors = [checked[succ] for succ in bb.successors]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def check_modified_block(
# This name could be printed in error messages, for example,
# when the linearity checker fails in the modifier body
checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals)
func_ty = check_modified_block_signature(checked_cfg.input_tys)
func_ty = check_modified_block_signature(modified_block, checked_cfg.input_tys)

checked_modifier = CheckedModifiedBlock(
def_id,
Expand All @@ -94,16 +94,20 @@ def _set_inout_if_non_copyable(var: Variable) -> Variable:
return var


def check_modified_block_signature(input_tys: list[Type]) -> FunctionType:
def check_modified_block_signature(
modified_block: ModifiedBlock, input_tys: list[Type]
) -> FunctionType:
"""Check and create the signature of a function definition for a body
of a `With` block."""
unitary_flags = modified_block.flags()

func_ty = FunctionType(
[
FuncInput(t, InputFlags.Inout if not t.copyable else InputFlags.NoFlags)
for t in input_tys
],
NoneType(),
unitary_flags=unitary_flags,
)
return func_ty

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import ast
from typing import Any

from guppylang_internals.ast_util import find_nodes, get_type, loop_in_ast
from guppylang_internals.checker.cfg_checker import CheckedBB, CheckedCFG
from guppylang_internals.checker.core import Place, contains_subscript
from guppylang_internals.checker.errors.generic import (
InvalidUnderDagger,
UnsupportedError,
)
from guppylang_internals.definition.value import CallableDef
from guppylang_internals.engine import ENGINE
from guppylang_internals.error import GuppyError, GuppyTypeError
from guppylang_internals.nodes import (
AnyCall,
BarrierExpr,
GlobalCall,
LocalCall,
PlaceNode,
ResultExpr,
StateResultExpr,
TensorCall,
)
from guppylang_internals.tys.errors import UnitaryCallError
from guppylang_internals.tys.qubit import contain_qubit_ty
from guppylang_internals.tys.ty import FunctionType, UnitaryFlags


def check_invalid_under_dagger(
fn_def: ast.FunctionDef, unitary_flags: UnitaryFlags
) -> None:
"""Check that there are no invalid constructs in a daggered CFG.
This checker checks the case the UnitaryFlags is given by
annotation (i.e., not inferred from `with dagger:`).
"""
if UnitaryFlags.Dagger not in unitary_flags:
return

for stmt in fn_def.body:
loops = loop_in_ast(stmt)
if len(loops) != 0:
loop = next(iter(loops))
err = InvalidUnderDagger(loop, "Loop")
raise GuppyError(err)
# Note: sub-diagnostic for dagger context is not available here

found = find_nodes(
lambda n: isinstance(n, ast.Assign | ast.AnnAssign | ast.AugAssign),
stmt,
{ast.FunctionDef},
)
if len(found) != 0:
assign = next(iter(found))
err = InvalidUnderDagger(assign, "Assignment")
raise GuppyError(err)


class BBUnitaryChecker(ast.NodeVisitor):
flags: UnitaryFlags

"""AST visitor that checks whether the modifiers (dagger, control, power)
are applicable."""

def check(self, bb: CheckedBB[Place], unitary_flags: UnitaryFlags) -> None:
self.flags = unitary_flags
for stmt in bb.statements:
self.visit(stmt)

def _check_classical_args(self, args: list[ast.expr]) -> bool:
for arg in args:
self.visit(arg)
if contain_qubit_ty(get_type(arg)):
return False
return True

def _check_call(self, node: AnyCall, ty: FunctionType) -> None:
classic = self._check_classical_args(node.args)
flag_ok = self.flags in ty.unitary_flags
if not classic and not flag_ok:
raise GuppyTypeError(
UnitaryCallError(node, self.flags & (~ty.unitary_flags))
)

def visit_GlobalCall(self, node: GlobalCall) -> None:
func = ENGINE.get_parsed(node.def_id)
assert isinstance(func, CallableDef)
self._check_call(node, func.ty)

def visit_LocalCall(self, node: LocalCall) -> None:
func = get_type(node.func)
assert isinstance(func, FunctionType)
self._check_call(node, func)

def visit_TensorCall(self, node: TensorCall) -> None:
self._check_call(node, node.tensor_ty)

def visit_BarrierExpr(self, node: BarrierExpr) -> None:
# Barrier is always allowed
pass

def visit_ResultExpr(self, node: ResultExpr) -> None:
# Result is always allowed
pass

def visit_StateResultExpr(self, node: StateResultExpr) -> None:
# StateResult is always allowed
pass

def _check_assign(self, node: Any) -> None:
assert isinstance(node, ast.Assign | ast.AnnAssign | ast.AugAssign)
if UnitaryFlags.Dagger in self.flags:
raise GuppyError(InvalidUnderDagger(node, "Assignment"))
if node.value is not None:
self.visit(node.value)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
self._check_assign(node)

def visit_Assign(self, node: ast.Assign) -> None:
self._check_assign(node)

def visit_AugAssign(self, node: ast.AugAssign) -> None:
self._check_assign(node)

def visit_PlaceNode(self, node: PlaceNode) -> None:
if UnitaryFlags.Dagger in self.flags and contains_subscript(node.place):
raise GuppyError(
UnsupportedError(node, "index access", True, "dagger context")
)


def check_cfg_unitary(
cfg: CheckedCFG[Place],
unitary_flags: UnitaryFlags,
) -> None:
bb_checker = BBUnitaryChecker()
for bb in cfg.bbs:
bb_checker.check(bb, unitary_flags)
Loading
Loading