Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions guppylang-internals/src/guppylang_internals/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,14 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None:
for item in node.items:
item.context_expr, bb = ExprBuilder.build(item.context_expr, self.cfg, bb)
modifiers.push(self._handle_withitem(item))

accumulated_flags = self.cfg.unitary_flags | modifiers.flags()
cfg = CFGBuilder().build(node.body, True, self.globals, accumulated_flags)
new_node = ModifiedBlock(
cfg=cfg, modifiers=modifiers, **dict(ast.iter_fields(node))
cfg=cfg,
modifiers=modifiers,
# we save the first modifier node for a better error rendering
first_modifier_node=node.items[0].context_expr,
**dict(ast.iter_fields(node)),
)

set_location_from(new_node, node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def check_cfg(
generic_params: dict[str, Parameter],
func_name: str,
globals: Globals,
first_modifier_node: ast.expr | None = None,
) -> CheckedCFG[Place]:
"""Type checks a control-flow graph.

Expand Down Expand Up @@ -154,7 +155,9 @@ def check_cfg(
# Finally, run the linearity check
from guppylang_internals.checker.linearity_checker import check_cfg_linearity

linearity_checked_cfg = check_cfg_linearity(checked_cfg, func_name, globals)
linearity_checked_cfg = check_cfg_linearity(
checked_cfg, func_name, globals, first_modifier_node=first_modifier_node
)

from guppylang_internals.checker.unitary_checker import check_cfg_unitary

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, ClassVar

from guppylang_internals.diagnostic import Error, Help, Note
Expand Down Expand Up @@ -162,27 +163,41 @@ class Fix(Help):
)


class InCallArg(Enum):
NonCall = auto()
Call = auto()
ModifierCall = auto()


@dataclass(frozen=True)
class NotOwnedError(Error):
title: ClassVar[str] = "Not owned"
place: Place
kind: UseKind
is_call_arg: bool
is_call_arg: InCallArg
func_name: str | None
calling_func_name: str

@property
def rendered_span_label(self) -> str:
if self.is_call_arg:
if self.is_call_arg != InCallArg.NonCall:
f = f"Function `{self.func_name}`" if self.func_name else "Function"
base_message = f"{f} wants to take ownership of this argument, but"
if self.is_call_arg == InCallArg.Call:
return (
f"{base_message} "
f"`{self.calling_func_name}` doesn't own `{self.place}`"
)
elif self.is_call_arg == InCallArg.ModifierCall:
return (
f"{base_message} we cannot transfer "
f"ownership inside a modifier body"
)
else:
return (
f"{f} wants to take ownership of this argument, but "
f"`{self.calling_func_name}` doesn't own `{self.place}`"
f"Cannot {self.kind.indicative} `{self.place}` since "
f"`{self.calling_func_name}` doesn't own it"
)
return (
f"Cannot {self.kind.indicative} `{self.place}` since "
f"`{self.calling_func_name}` doesn't own it"
)

@dataclass(frozen=True)
class MakeOwned(Help):
Expand All @@ -197,6 +212,14 @@ class MakeCopy(Help):
"Or consider copying this argument: `{place}.copy()`"
)

@dataclass(frozen=True)
class DefinedHere(Note):
span_label: ClassVar[str] = "Argument `{place.root.name}` defined here ..."

@dataclass(frozen=True)
class ModifierBlock(Note):
span_label: ClassVar[str] = "outside the modifier block"


@dataclass(frozen=True)
class MoveOutOfSubscriptError(Error):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
BorrowSubPlaceUsedError,
ComprAlreadyUsedError,
DropAfterCallError,
InCallArg,
MoveOutOfSubscriptError,
NonCopyableCaptureError,
NonCopyablePartialApplyError,
Expand Down Expand Up @@ -197,6 +198,7 @@ class BBLinearityChecker(ast.NodeVisitor):
func_name: str
func_inputs: dict[PlaceId, Variable]
globals: Globals
first_modifier_node: ast.expr | None

def check(
self,
Expand All @@ -205,6 +207,7 @@ def check(
func_name: str,
func_inputs: dict[PlaceId, Variable],
globals: Globals,
first_modifier_node: ast.expr | None = None,
) -> Scope:
# Manufacture a scope that holds all places that are live at the start
# of this BB
Expand All @@ -215,6 +218,7 @@ def check(
self.func_name = func_name
self.func_inputs = func_inputs
self.globals = globals
self.first_modifier_node = first_modifier_node

# Open up a new nested scope to check the BB contents. This way we can track
# when we use variables from the outside vs ones assigned in this BB. The only
Expand Down Expand Up @@ -247,19 +251,33 @@ def visit_PlaceNode(
# `_visit_call_args` helper will set `use_kind=UseKind.BORROW`.
is_inout_arg = use_kind == UseKind.BORROW
if is_inout_var(node.place) and not is_inout_arg:
if self.first_modifier_node:
call_kind = InCallArg.ModifierCall
elif is_call_arg:
call_kind = InCallArg.Call
else:
call_kind = InCallArg.NonCall

err: Error = NotOwnedError(
node,
node.place,
use_kind,
is_call_arg is not None,
call_kind,
self._call_name(is_call_arg),
self.func_name,
)
arg_span = self.func_inputs[node.place.root.id].defined_at
err.add_sub_diagnostic(NotOwnedError.MakeOwned(arg_span))
# If the argument is a classical array, we can also suggest copying it.
if has_explicit_copy(node.place.ty):
err.add_sub_diagnostic(NotOwnedError.MakeCopy(node))
if self.first_modifier_node:
# If we are under a modifier we need a special error message
err.add_sub_diagnostic(NotOwnedError.DefinedHere(arg_span))
err.add_sub_diagnostic(
NotOwnedError.ModifierBlock(self.first_modifier_node)
)
else:
err.add_sub_diagnostic(NotOwnedError.MakeOwned(arg_span))
# If the argument is a classical array, we can also suggest copying it.
if has_explicit_copy(node.place.ty):
err.add_sub_diagnostic(NotOwnedError.MakeCopy(node))
raise GuppyError(err)
# Places involving subscripts are handled differently since we ignore everything
# after the subscript for the purposes of linearity checking.
Expand Down Expand Up @@ -804,7 +822,10 @@ def is_simple_literal_subscript(


def check_cfg_linearity(
cfg: "CheckedCFG[Variable]", func_name: str, globals: Globals
cfg: "CheckedCFG[Variable]",
func_name: str,
globals: Globals,
first_modifier_node: ast.expr | None = None,
) -> "CheckedCFG[Place]":
"""Checks whether a CFG satisfies the linearity requirements.

Expand All @@ -822,6 +843,7 @@ def check_cfg_linearity(
func_name=func_name,
func_inputs=func_inputs,
globals=globals,
first_modifier_node=first_modifier_node,
)
for bb in cfg.bbs
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,15 @@ def check_modified_block(
inputs = non_copyable_front_others_back(inputs)
def_id = DefId.fresh()
globals = ctx.globals

# TODO: Ad hoc name for the new function
# This name could be printed in error messages, for example,
# when the linearity checker fails in the modifier body
checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals)
checked_cfg = check_cfg(
cfg,
inputs,
NoneType(),
{},
"__modified__()",
globals,
first_modifier_node=modified_block.first_modifier_node,
)
func_ty = check_modified_block_signature(modified_block, checked_cfg.input_tys)

checked_modifier = CheckedModifiedBlock(
Expand Down
9 changes: 8 additions & 1 deletion guppylang-internals/src/guppylang_internals/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,13 +783,20 @@ def flags(self) -> UnitaryFlags:

class ModifiedBlock(ast.With):
cfg: "CFG"
first_modifier_node: ast.expr

def __init__(
self, cfg: "CFG", modifiers: "Modifiers", *args: Any, **kwargs: Any
self,
cfg: "CFG",
modifiers: "Modifiers",
first_modifier_node: ast.expr,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.cfg = cfg
self.modifiers = modifiers
self.first_modifier_node = first_modifier_node

@property
def dagger(self) -> list[Dagger]:
Expand Down
16 changes: 0 additions & 16 deletions tests/error/modifier_errors/captured_var_inout_own.err

This file was deleted.

17 changes: 17 additions & 0 deletions tests/error/modifier_errors/captured_var_inout_own1.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Error: Not owned (at $FILE:16:16)
|
14 | with power(2):
15 | pass
16 | discard(a)
| ^ Function `discard` wants to take ownership of this argument,
| but we cannot transfer ownership inside a modifier body

Notes:
|
11 | def test() -> None:
12 | a = qubit()
| - Argument `a` defined here ...
13 | with dagger:
| ------ outside the modifier block

Guppy compilation failed due to 1 previous error
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from guppylang.decorator import guppy
from guppylang.std.quantum import qubit, owned
from tests.util import compile_guppy


@guppy.declare(dagger=True)
def discard(q: qubit @ owned) -> None: ...


# TODO: The error message is not prefect.
@guppy
@compile_guppy
def test() -> None:
a = qubit()
with dagger:
with power(2):
pass
discard(a)


test.compile()
18 changes: 18 additions & 0 deletions tests/error/modifier_errors/captured_var_inout_own2.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Error: Not owned (at $FILE:15:20)
|
13 | with dagger:
14 | with power(3):
15 | discard(a)
| ^ Function `discard` wants to take ownership of this argument,
| but we cannot transfer ownership inside a modifier body

Notes:
|
11 | def test() -> None:
12 | a = qubit()
| - Argument `a` defined here ...
13 | with dagger:
14 | with power(3):
| -------- outside the modifier block

Guppy compilation failed due to 1 previous error
17 changes: 17 additions & 0 deletions tests/error/modifier_errors/captured_var_inout_own2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from guppylang.decorator import guppy
from guppylang.std.quantum import qubit, owned
from tests.util import compile_guppy


@guppy.declare(dagger=True)
def discard(q: qubit @ owned) -> None: ...


@compile_guppy
def test() -> None:
a = qubit()
with dagger:
with power(3):
discard(a)


Loading