diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index c29eab383..addda82ea 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -363,11 +363,14 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None: for item in node.items: item.context_expr, bb = ExprBuilder.build(item.context_expr, self.cfg, bb) modifiers.push(self._handle_withitem(item)) - accumulated_flags = self.cfg.unitary_flags | modifiers.flags() cfg = CFGBuilder().build(node.body, True, self.globals, accumulated_flags) new_node = ModifiedBlock( - cfg=cfg, modifiers=modifiers, **dict(ast.iter_fields(node)) + cfg=cfg, + modifiers=modifiers, + # we save the first modifier node for a better error rendering + first_modifier_node=node.items[0].context_expr, + **dict(ast.iter_fields(node)), ) set_location_from(new_node, node) diff --git a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py index d83c7151f..5ca805935 100644 --- a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py @@ -76,6 +76,7 @@ def check_cfg( generic_params: dict[str, Parameter], func_name: str, globals: Globals, + first_modifier_node: ast.expr | None = None, ) -> CheckedCFG[Place]: """Type checks a control-flow graph. @@ -154,7 +155,9 @@ def check_cfg( # Finally, run the linearity check from guppylang_internals.checker.linearity_checker import check_cfg_linearity - linearity_checked_cfg = check_cfg_linearity(checked_cfg, func_name, globals) + linearity_checked_cfg = check_cfg_linearity( + checked_cfg, func_name, globals, first_modifier_node=first_modifier_node + ) from guppylang_internals.checker.unitary_checker import check_cfg_unitary diff --git a/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py b/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py index 5468bc0c4..6d47fc42d 100644 --- a/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py +++ b/guppylang-internals/src/guppylang_internals/checker/errors/linearity.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum, auto from typing import TYPE_CHECKING, ClassVar from guppylang_internals.diagnostic import Error, Help, Note @@ -162,27 +163,41 @@ class Fix(Help): ) +class InCallArg(Enum): + NonCall = auto() + Call = auto() + ModifierCall = auto() + + @dataclass(frozen=True) class NotOwnedError(Error): title: ClassVar[str] = "Not owned" place: Place kind: UseKind - is_call_arg: bool + is_call_arg: InCallArg func_name: str | None calling_func_name: str @property def rendered_span_label(self) -> str: - if self.is_call_arg: + if self.is_call_arg != InCallArg.NonCall: f = f"Function `{self.func_name}`" if self.func_name else "Function" + base_message = f"{f} wants to take ownership of this argument, but" + if self.is_call_arg == InCallArg.Call: + return ( + f"{base_message} " + f"`{self.calling_func_name}` doesn't own `{self.place}`" + ) + elif self.is_call_arg == InCallArg.ModifierCall: + return ( + f"{base_message} we cannot transfer " + f"ownership inside a modifier body" + ) + else: return ( - f"{f} wants to take ownership of this argument, but " - f"`{self.calling_func_name}` doesn't own `{self.place}`" + f"Cannot {self.kind.indicative} `{self.place}` since " + f"`{self.calling_func_name}` doesn't own it" ) - return ( - f"Cannot {self.kind.indicative} `{self.place}` since " - f"`{self.calling_func_name}` doesn't own it" - ) @dataclass(frozen=True) class MakeOwned(Help): @@ -197,6 +212,14 @@ class MakeCopy(Help): "Or consider copying this argument: `{place}.copy()`" ) + @dataclass(frozen=True) + class DefinedHere(Note): + span_label: ClassVar[str] = "Argument `{place.root.name}` defined here ..." + + @dataclass(frozen=True) + class ModifierBlock(Note): + span_label: ClassVar[str] = "outside the modifier block" + @dataclass(frozen=True) class MoveOutOfSubscriptError(Error): diff --git a/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py b/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py index 3647550f2..9a066e0ca 100644 --- a/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py @@ -35,6 +35,7 @@ BorrowSubPlaceUsedError, ComprAlreadyUsedError, DropAfterCallError, + InCallArg, MoveOutOfSubscriptError, NonCopyableCaptureError, NonCopyablePartialApplyError, @@ -197,6 +198,7 @@ class BBLinearityChecker(ast.NodeVisitor): func_name: str func_inputs: dict[PlaceId, Variable] globals: Globals + first_modifier_node: ast.expr | None def check( self, @@ -205,6 +207,7 @@ def check( func_name: str, func_inputs: dict[PlaceId, Variable], globals: Globals, + first_modifier_node: ast.expr | None = None, ) -> Scope: # Manufacture a scope that holds all places that are live at the start # of this BB @@ -215,6 +218,7 @@ def check( self.func_name = func_name self.func_inputs = func_inputs self.globals = globals + self.first_modifier_node = first_modifier_node # Open up a new nested scope to check the BB contents. This way we can track # when we use variables from the outside vs ones assigned in this BB. The only @@ -247,19 +251,33 @@ def visit_PlaceNode( # `_visit_call_args` helper will set `use_kind=UseKind.BORROW`. is_inout_arg = use_kind == UseKind.BORROW if is_inout_var(node.place) and not is_inout_arg: + if self.first_modifier_node: + call_kind = InCallArg.ModifierCall + elif is_call_arg: + call_kind = InCallArg.Call + else: + call_kind = InCallArg.NonCall + err: Error = NotOwnedError( node, node.place, use_kind, - is_call_arg is not None, + call_kind, self._call_name(is_call_arg), self.func_name, ) arg_span = self.func_inputs[node.place.root.id].defined_at - err.add_sub_diagnostic(NotOwnedError.MakeOwned(arg_span)) - # If the argument is a classical array, we can also suggest copying it. - if has_explicit_copy(node.place.ty): - err.add_sub_diagnostic(NotOwnedError.MakeCopy(node)) + if self.first_modifier_node: + # If we are under a modifier we need a special error message + err.add_sub_diagnostic(NotOwnedError.DefinedHere(arg_span)) + err.add_sub_diagnostic( + NotOwnedError.ModifierBlock(self.first_modifier_node) + ) + else: + err.add_sub_diagnostic(NotOwnedError.MakeOwned(arg_span)) + # If the argument is a classical array, we can also suggest copying it. + if has_explicit_copy(node.place.ty): + err.add_sub_diagnostic(NotOwnedError.MakeCopy(node)) raise GuppyError(err) # Places involving subscripts are handled differently since we ignore everything # after the subscript for the purposes of linearity checking. @@ -804,7 +822,10 @@ def is_simple_literal_subscript( def check_cfg_linearity( - cfg: "CheckedCFG[Variable]", func_name: str, globals: Globals + cfg: "CheckedCFG[Variable]", + func_name: str, + globals: Globals, + first_modifier_node: ast.expr | None = None, ) -> "CheckedCFG[Place]": """Checks whether a CFG satisfies the linearity requirements. @@ -822,6 +843,7 @@ def check_cfg_linearity( func_name=func_name, func_inputs=func_inputs, globals=globals, + first_modifier_node=first_modifier_node, ) for bb in cfg.bbs } diff --git a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py index 2edb6a9f2..15ec382d9 100644 --- a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py @@ -66,11 +66,15 @@ def check_modified_block( inputs = non_copyable_front_others_back(inputs) def_id = DefId.fresh() globals = ctx.globals - - # TODO: Ad hoc name for the new function - # This name could be printed in error messages, for example, - # when the linearity checker fails in the modifier body - checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals) + checked_cfg = check_cfg( + cfg, + inputs, + NoneType(), + {}, + "__modified__()", + globals, + first_modifier_node=modified_block.first_modifier_node, + ) func_ty = check_modified_block_signature(modified_block, checked_cfg.input_tys) checked_modifier = CheckedModifiedBlock( diff --git a/guppylang-internals/src/guppylang_internals/nodes.py b/guppylang-internals/src/guppylang_internals/nodes.py index 73459c798..1d4269fb6 100644 --- a/guppylang-internals/src/guppylang_internals/nodes.py +++ b/guppylang-internals/src/guppylang_internals/nodes.py @@ -783,13 +783,20 @@ def flags(self) -> UnitaryFlags: class ModifiedBlock(ast.With): cfg: "CFG" + first_modifier_node: ast.expr def __init__( - self, cfg: "CFG", modifiers: "Modifiers", *args: Any, **kwargs: Any + self, + cfg: "CFG", + modifiers: "Modifiers", + first_modifier_node: ast.expr, + *args: Any, + **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self.cfg = cfg self.modifiers = modifiers + self.first_modifier_node = first_modifier_node @property def dagger(self) -> list[Dagger]: diff --git a/tests/error/modifier_errors/captured_var_inout_own.err b/tests/error/modifier_errors/captured_var_inout_own.err deleted file mode 100644 index 005d590e1..000000000 --- a/tests/error/modifier_errors/captured_var_inout_own.err +++ /dev/null @@ -1,16 +0,0 @@ -Error: Not owned (at $FILE:14:16) - | -12 | a = qubit() -13 | with dagger: -14 | discard(a) - | ^ Function `discard` wants to take ownership of this argument, - | but `__modified__()` doesn't own `a` - -Note: - | -11 | def test() -> None: -12 | a = qubit() - | - Argument `a` is only borrowed. Consider taking ownership: - | `a: qubit @owned` - -Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_var_inout_own1.err b/tests/error/modifier_errors/captured_var_inout_own1.err new file mode 100644 index 000000000..1a509f63c --- /dev/null +++ b/tests/error/modifier_errors/captured_var_inout_own1.err @@ -0,0 +1,17 @@ +Error: Not owned (at $FILE:16:16) + | +14 | with power(2): +15 | pass +16 | discard(a) + | ^ Function `discard` wants to take ownership of this argument, + | but we cannot transfer ownership inside a modifier body + +Notes: + | +11 | def test() -> None: +12 | a = qubit() + | - Argument `a` defined here ... +13 | with dagger: + | ------ outside the modifier block + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_var_inout_own.py b/tests/error/modifier_errors/captured_var_inout_own1.py similarity index 71% rename from tests/error/modifier_errors/captured_var_inout_own.py rename to tests/error/modifier_errors/captured_var_inout_own1.py index 211baef94..0ff4b815b 100644 --- a/tests/error/modifier_errors/captured_var_inout_own.py +++ b/tests/error/modifier_errors/captured_var_inout_own1.py @@ -1,17 +1,17 @@ from guppylang.decorator import guppy from guppylang.std.quantum import qubit, owned +from tests.util import compile_guppy @guppy.declare(dagger=True) def discard(q: qubit @ owned) -> None: ... -# TODO: The error message is not prefect. -@guppy +@compile_guppy def test() -> None: a = qubit() with dagger: + with power(2): + pass discard(a) - -test.compile() diff --git a/tests/error/modifier_errors/captured_var_inout_own2.err b/tests/error/modifier_errors/captured_var_inout_own2.err new file mode 100644 index 000000000..130825131 --- /dev/null +++ b/tests/error/modifier_errors/captured_var_inout_own2.err @@ -0,0 +1,18 @@ +Error: Not owned (at $FILE:15:20) + | +13 | with dagger: +14 | with power(3): +15 | discard(a) + | ^ Function `discard` wants to take ownership of this argument, + | but we cannot transfer ownership inside a modifier body + +Notes: + | +11 | def test() -> None: +12 | a = qubit() + | - Argument `a` defined here ... +13 | with dagger: +14 | with power(3): + | -------- outside the modifier block + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/captured_var_inout_own2.py b/tests/error/modifier_errors/captured_var_inout_own2.py new file mode 100644 index 000000000..e341cc43f --- /dev/null +++ b/tests/error/modifier_errors/captured_var_inout_own2.py @@ -0,0 +1,17 @@ +from guppylang.decorator import guppy +from guppylang.std.quantum import qubit, owned +from tests.util import compile_guppy + + +@guppy.declare(dagger=True) +def discard(q: qubit @ owned) -> None: ... + + +@compile_guppy +def test() -> None: + a = qubit() + with dagger: + with power(3): + discard(a) + +