From fbc9701753367c6f87d1b9d4bff7b16b45488977 Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Thu, 2 Oct 2025 15:12:12 +0100 Subject: [PATCH 01/13] [wip] compiler for modifier first modifier compiler, with extensions imported modifier_compiler now directly compiles cfgs minor reference to tket repository with new extensions integration test for modifier_compiler minor --- .../compiler/modifier_compiler.py | 192 ++++++++++++++++++ .../compiler/stmt_compiler.py | 8 + .../std/_internal/compiler/tket_exts.py | 8 + pyproject.toml | 2 +- tests/integration/test_modifier.py | 135 ++++++++++++ 5 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py create mode 100644 tests/integration/test_modifier.py diff --git a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py new file mode 100644 index 000000000..b00cafb5f --- /dev/null +++ b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py @@ -0,0 +1,192 @@ +"""Hugr generation for modifiers.""" + +from hugr import Wire, ops +from hugr import tys as ht + +from guppylang_internals.ast_util import get_type +from guppylang_internals.checker.modifier_checker import non_copyable_front_others_back +from guppylang_internals.compiler.cfg_compiler import compile_cfg +from guppylang_internals.compiler.core import CompilerContext, DFContainer +from guppylang_internals.compiler.expr_compiler import ( + ExprCompiler, + array_unwrap_elem, + array_wrap_elem, +) +from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode +from guppylang_internals.std._internal.compiler.array import ( + array_convert_from_std_array, + array_convert_to_std_array, + array_map, + array_new, + standard_array_type, + unpack_array, +) +from guppylang_internals.std._internal.compiler.tket_exts import MODIFIER_EXTENSION +from guppylang_internals.tys.builtin import int_type, is_array_type +from guppylang_internals.tys.ty import InputFlags + + +def compile_modified_block( + modified_block: CheckedModifiedBlock, + dfg: DFContainer, + ctx: CompilerContext, + expr_compiler: ExprCompiler, +) -> Wire: + DAGGER_OP_NAME = "DaggerModifier" + CONTROL_OP_NAME = "ControlModifier" + POWER_OP_NAME = "PowerModifier" + + # Define types + body_ty = modified_block.ty + # TODO: Shouldn't this be `to_hugr_poly`? + hugr_ty = body_ty.to_hugr(ctx) + in_out_ht = [ + fn_inp.ty.to_hugr(ctx) + for fn_inp in body_ty.inputs + if InputFlags.Inout in fn_inp.flags and InputFlags.Comptime not in fn_inp.flags + ] + other_in_ht = [ + fn_inp.ty.to_hugr(ctx) + for fn_inp in body_ty.inputs + if InputFlags.Inout not in fn_inp.flags + and InputFlags.Comptime not in fn_inp.flags + ] + in_out_arg = ht.ListArg([t.type_arg() for t in in_out_ht]) + other_in_arg = ht.ListArg([t.type_arg() for t in other_in_ht]) + + func_builder = dfg.builder.module_root_builder().define_function( + str(modified_block), hugr_ty.input, hugr_ty.output + ) + + # compile body + cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx) + func_builder.set_outputs(*cfg) + + # LoadFunc + call = dfg.builder.load_function(func_builder, hugr_ty) + + # Function inputs + captured = [v for v, _ in modified_block.captured.values()] + captured = non_copyable_front_others_back(captured) + args = [dfg[v] for v in captured] + + if modified_block.is_dagger(): + dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty]) + call = dfg.builder.add_op( + ops.ExtOp( + MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME), + dagger_ty, + [in_out_arg, other_in_arg], + ), + call, + ) + qubit_num_args = [] + if modified_block.has_control(): + for control in modified_block.control: + # definition of types + assert control.qubit_num is not None + qubit_num: ht.TypeArg + if isinstance(control.qubit_num, int): + qubit_num = ht.BoundedNatArg(control.qubit_num) + else: + qubit_num = control.qubit_num.to_arg().to_hugr(ctx) + qubit_num_args.append(qubit_num) + std_array = standard_array_type(ht.Qubit, qubit_num) + + # control operator + input_fn_ty = hugr_ty + output_fn_ty = ht.FunctionType( + [std_array, *hugr_ty.input], [std_array, *hugr_ty.output] + ) + op = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME).instantiate( + [qubit_num, in_out_arg, other_in_arg], + ht.FunctionType([input_fn_ty], [output_fn_ty]), + ) + call = dfg.builder.add_op(op, call) + # update types + in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems]) + hugr_ty = output_fn_ty + if modified_block.is_power(): + power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty]) + for power in modified_block.power: + num = expr_compiler.compile(power.iter, dfg) + call = dfg.builder.add_op( + ops.ExtOp( + MODIFIER_EXTENSION.get_op(POWER_OP_NAME), + power_ty, + [in_out_arg, other_in_arg], + ), + call, + num, + ) + + # Prepare control arguments + ctrl_args: list[Wire] = [] + for i, control in enumerate(modified_block.control): + if is_array_type(get_type(control.ctrl[0])): + input_array = expr_compiler.compile(control.ctrl[0], dfg) + + unwrap = array_unwrap_elem(ctx) + unwrap = dfg.builder.load_function( + unwrap, + instantiation=ht.FunctionType([ht.Option(ht.Qubit)], [ht.Qubit]), + type_args=[ht.TypeTypeArg(ht.Qubit)], + ) + map_op = array_map(ht.Option(ht.Qubit), qubit_num_args[i], ht.Qubit) + unwrapped_array = dfg.builder.add_op(map_op, input_array, unwrap) + + unwrapped_array = dfg.builder.add_op( + array_convert_to_std_array(ht.Qubit, qubit_num_args[i]), unwrapped_array + ) + + ctrl_args.extend(unwrapped_array) + else: + cs = [expr_compiler.compile(c, dfg) for c in control.ctrl] + c_node = dfg.builder.add_op(array_new(ht.Qubit, len(control.ctrl)), *cs) + val_to_std = array_convert_to_std_array(ht.Qubit, qubit_num_args[i]) + c_node = dfg.builder.add_op(val_to_std, *c_node) + ctrl_args.append(c_node) + + # Call + call = dfg.builder.add_op( + ops.CallIndirect(), + call, + *ctrl_args, + *args, + ) + outports = iter(call) + + # Unpack controls + for i, control in enumerate(modified_block.control): + outport = next(outports) + if is_array_type(get_type(control.ctrl[0])): + result_array = dfg.builder.add_op( + array_convert_from_std_array(ht.Qubit, qubit_num_args[i]), outport + ) + + wrap = array_wrap_elem(ctx) + wrap = dfg.builder.load_function( + wrap, + instantiation=ht.FunctionType([ht.Qubit], [ht.Option(ht.Qubit)]), + type_args=[ht.TypeTypeArg(ht.Qubit)], + ) + map_op = array_map(ht.Qubit, qubit_num_args[i], ht.Option(ht.Qubit)) + new_c = dfg.builder.add_op(map_op, result_array, wrap) + + c = control.ctrl[0] + assert isinstance(c, PlaceNode) + + dfg[c.place] = new_c + else: + val_from_std = array_convert_from_std_array(ht.Qubit, qubit_num_args[i]) + std_arr = dfg.builder.add_op(val_from_std, outport) + unpacked = unpack_array(dfg.builder, std_arr) + for c, wire in zip(control.ctrl, unpacked, strict=False): + assert isinstance(c, PlaceNode) + dfg[c.place] = wire + + for arg in captured: + if InputFlags.Inout in arg.flags: + dfg[arg] = next(outports) + + return call diff --git a/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py index 2a5e9031c..036ac00d1 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/stmt_compiler.py @@ -18,6 +18,7 @@ from guppylang_internals.error import InternalGuppyError from guppylang_internals.nodes import ( ArrayUnpack, + CheckedModifiedBlock, CheckedNestedFunctionDef, IterableUnpack, PlaceNode, @@ -221,3 +222,10 @@ def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None var = Variable(node.name, node.ty, node) loaded_func = compile_local_func_def(node, self.dfg, self.ctx) self.dfg[var] = loaded_func + + def visit_CheckedModifiedBlock(self, node: CheckedModifiedBlock) -> None: + from guppylang_internals.compiler.modifier_compiler import ( + compile_modified_block, + ) + + compile_modified_block(node, self.dfg, self.ctx, self.expr_compiler) diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py index fe17ea0eb..f6c2db294 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py @@ -5,7 +5,10 @@ from tket_exts import ( debug, futures, + global_phase, guppy, + modifier, + opaque_bool, qsystem, qsystem_random, qsystem_utils, @@ -19,6 +22,7 @@ DEBUG_EXTENSION = debug() FUTURES_EXTENSION = futures() GUPPY_EXTENSION = guppy() +MODIFIER_EXTENSION = modifier() QSYSTEM_EXTENSION = qsystem() QSYSTEM_RANDOM_EXTENSION = qsystem_random() QSYSTEM_UTILS_EXTENSION = qsystem_utils() @@ -26,6 +30,8 @@ RESULT_EXTENSION = result() ROTATION_EXTENSION = rotation() WASM_EXTENSION = wasm() +MODIFIER_EXTENSION = modifier() +GLOBAL_PHASE_EXTENSION = global_phase() TKET_EXTENSIONS = [ BOOL_EXTENSION, @@ -39,6 +45,8 @@ RESULT_EXTENSION, ROTATION_EXTENSION, WASM_EXTENSION, + MODIFIER_EXTENSION, + GLOBAL_PHASE_EXTENSION, ] diff --git a/pyproject.toml b/pyproject.toml index 35ba67af2..e5703256d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ miette-py = { workspace = true } # Uncomment these to test the latest dependency version during development # hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "50a2bac" } -# tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "aca944c" } +tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "73ff49b" } [build-system] requires = ["hatchling"] diff --git a/tests/integration/test_modifier.py b/tests/integration/test_modifier.py new file mode 100644 index 000000000..2108d5a85 --- /dev/null +++ b/tests/integration/test_modifier.py @@ -0,0 +1,135 @@ +from guppylang.decorator import guppy +from guppylang.std.quantum import qubit +from guppylang.std.num import nat +from guppylang.std.builtins import owned +from guppylang.std.array import array + +# Dummy variables to suppress Undefined name +# TODO: `ruff` fails when without these, which need to be fixed +dagger = object() +control = object() +power = object() + + +def test_dagger_simple(validate): + @guppy + def bar() -> None: + with dagger: + pass + + validate(bar.compile_function()) + + +def test_dagger_call_simple(validate): + @guppy + def bar() -> None: + with dagger(): + pass + + validate(bar.compile_function()) + + +def test_control_simple(validate): + @guppy + def bar(q: qubit) -> None: + with control(q): + pass + + validate(bar.compile_function()) + + +def test_control_multiple(validate): + @guppy + def bar(q1: qubit, q2: qubit) -> None: + with control(q1, q2): + pass + + validate(bar.compile_function()) + + +def test_control_array(validate): + @guppy + def bar(q: array[qubit, 3]) -> None: + with control(q): + pass + + validate(bar.compile_function()) + + +def test_power_simple(validate): + @guppy + def bar(n: nat) -> None: + with power(n): + pass + + validate(bar.compile_function()) + + +def test_call_in_modifier(validate): + @guppy + def foo() -> None: + pass + + @guppy + def bar() -> None: + with dagger: + foo() + + validate(bar.compile_function()) + + +def test_combined_modifiers(validate): + @guppy + def bar(q: qubit) -> None: + with control(q), power(2), dagger: + pass + + validate(bar.compile_function()) + + +def test_nested_modifiers(validate): + @guppy + def bar(q: qubit) -> None: + with control(q): + with power(2): + with dagger: + pass + + validate(bar.compile_function()) + + +def test_free_linear_variable_in_modifier(validate): + T = guppy.type_var("T", copyable=False, droppable=False) + + @guppy.declare + def use(a: T) -> None: ... + + @guppy.declare + def discard(a: T @ owned) -> None: ... + + @guppy + def bar(q: qubit) -> None: + a = array(qubit()) + with control(q): + use(a) + discard(a) + + validate(bar.compile_function()) + + +def test_free_copyable_variable_in_modifier(validate): + T = guppy.type_var("T", copyable=True, droppable=True) + + @guppy.declare + def use(a: T) -> None: ... + + @guppy.declare + def discard(a: T @ owned) -> None: ... + + @guppy + def bar(q: array[qubit, 3]) -> None: + a = 3 + with control(q): + use(a) + + validate(bar.compile_function()) From ea19be7e08f5adbb171883194ef0db522f22ab1c Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Mon, 6 Oct 2025 11:29:47 +0100 Subject: [PATCH 02/13] first implementation of modifier compiler --- .../compiler/modifier_compiler.py | 119 ++++++++++-------- .../src/guppylang_internals/nodes.py | 4 +- .../std/_internal/compiler/tket_exts.py | 1 - pyproject.toml | 2 +- 4 files changed, 69 insertions(+), 57 deletions(-) diff --git a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py index b00cafb5f..01827477c 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py @@ -1,6 +1,6 @@ """Hugr generation for modifiers.""" -from hugr import Wire, ops +from hugr import Node, Wire, ops from hugr import tys as ht from guppylang_internals.ast_util import get_type @@ -36,9 +36,13 @@ def compile_modified_block( CONTROL_OP_NAME = "ControlModifier" POWER_OP_NAME = "PowerModifier" - # Define types + dagger_op_def = MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME) + control_op_def = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME) + power_op_def = MODIFIER_EXTENSION.get_op(POWER_OP_NAME) + body_ty = modified_block.ty - # TODO: Shouldn't this be `to_hugr_poly`? + # TODO: Shouldn't this be `to_hugr_poly` since it can contain + # a variable with a generic type? hugr_ty = body_ty.to_hugr(ctx) in_out_ht = [ fn_inp.ty.to_hugr(ctx) @@ -70,20 +74,33 @@ def compile_modified_block( captured = non_copyable_front_others_back(captured) args = [dfg[v] for v in captured] - if modified_block.is_dagger(): + # Apply modifiers + if modified_block.has_dagger(): dagger_ty = ht.FunctionType([hugr_ty], [hugr_ty]) call = dfg.builder.add_op( ops.ExtOp( - MODIFIER_EXTENSION.get_op(DAGGER_OP_NAME), + dagger_op_def, dagger_ty, [in_out_arg, other_in_arg], ), call, ) + if modified_block.has_power(): + power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty]) + for power in modified_block.power: + num = expr_compiler.compile(power.iter, dfg) + call = dfg.builder.add_op( + ops.ExtOp( + power_op_def, + power_ty, + [in_out_arg, other_in_arg], + ), + call, + num, + ) qubit_num_args = [] if modified_block.has_control(): for control in modified_block.control: - # definition of types assert control.qubit_num is not None qubit_num: ht.TypeArg if isinstance(control.qubit_num, int): @@ -98,54 +115,47 @@ def compile_modified_block( output_fn_ty = ht.FunctionType( [std_array, *hugr_ty.input], [std_array, *hugr_ty.output] ) - op = MODIFIER_EXTENSION.get_op(CONTROL_OP_NAME).instantiate( - [qubit_num, in_out_arg, other_in_arg], + op = ops.ExtOp( + control_op_def, ht.FunctionType([input_fn_ty], [output_fn_ty]), + [qubit_num, in_out_arg, other_in_arg], ) call = dfg.builder.add_op(op, call) # update types in_out_arg = ht.ListArg([std_array.type_arg(), *in_out_arg.elems]) hugr_ty = output_fn_ty - if modified_block.is_power(): - power_ty = ht.FunctionType([hugr_ty, int_type().to_hugr(ctx)], [hugr_ty]) - for power in modified_block.power: - num = expr_compiler.compile(power.iter, dfg) - call = dfg.builder.add_op( - ops.ExtOp( - MODIFIER_EXTENSION.get_op(POWER_OP_NAME), - power_ty, - [in_out_arg, other_in_arg], - ), - call, - num, - ) # Prepare control arguments ctrl_args: list[Wire] = [] + unwrap: Node | None = None for i, control in enumerate(modified_block.control): if is_array_type(get_type(control.ctrl[0])): - input_array = expr_compiler.compile(control.ctrl[0], dfg) - - unwrap = array_unwrap_elem(ctx) - unwrap = dfg.builder.load_function( + control_array = expr_compiler.compile(control.ctrl[0], dfg) + # if `unwrap` function is already loaded, reuse it, otherwise create it + if unwrap is None: + unwrap = dfg.builder.load_function( + array_unwrap_elem(ctx), + instantiation=ht.FunctionType([ht.Option(ht.Qubit)], [ht.Qubit]), + type_args=[ht.TypeTypeArg(ht.Qubit)], + ) + control_array = dfg.builder.add_op( + array_map(ht.Option(ht.Qubit), qubit_num_args[i], ht.Qubit), + control_array, unwrap, - instantiation=ht.FunctionType([ht.Option(ht.Qubit)], [ht.Qubit]), - type_args=[ht.TypeTypeArg(ht.Qubit)], ) - map_op = array_map(ht.Option(ht.Qubit), qubit_num_args[i], ht.Qubit) - unwrapped_array = dfg.builder.add_op(map_op, input_array, unwrap) - - unwrapped_array = dfg.builder.add_op( - array_convert_to_std_array(ht.Qubit, qubit_num_args[i]), unwrapped_array + control_array = dfg.builder.add_op( + array_convert_to_std_array(ht.Qubit, qubit_num_args[i]), control_array ) - - ctrl_args.extend(unwrapped_array) + ctrl_args.append(control_array) else: cs = [expr_compiler.compile(c, dfg) for c in control.ctrl] - c_node = dfg.builder.add_op(array_new(ht.Qubit, len(control.ctrl)), *cs) - val_to_std = array_convert_to_std_array(ht.Qubit, qubit_num_args[i]) - c_node = dfg.builder.add_op(val_to_std, *c_node) - ctrl_args.append(c_node) + control_array = dfg.builder.add_op( + array_new(ht.Qubit, len(control.ctrl)), *cs + ) + control_array = dfg.builder.add_op( + array_convert_to_std_array(ht.Qubit, qubit_num_args[i]), *control_array + ) + ctrl_args.append(control_array) # Call call = dfg.builder.add_op( @@ -157,33 +167,36 @@ def compile_modified_block( outports = iter(call) # Unpack controls + wrap: Node | None = None for i, control in enumerate(modified_block.control): outport = next(outports) if is_array_type(get_type(control.ctrl[0])): - result_array = dfg.builder.add_op( + control_array = dfg.builder.add_op( array_convert_from_std_array(ht.Qubit, qubit_num_args[i]), outport ) - - wrap = array_wrap_elem(ctx) - wrap = dfg.builder.load_function( + if wrap is None: + wrap = dfg.builder.load_function( + array_wrap_elem(ctx), + instantiation=ht.FunctionType([ht.Qubit], [ht.Option(ht.Qubit)]), + type_args=[ht.TypeTypeArg(ht.Qubit)], + ) + control_array = dfg.builder.add_op( + array_map(ht.Qubit, qubit_num_args[i], ht.Option(ht.Qubit)), + control_array, wrap, - instantiation=ht.FunctionType([ht.Qubit], [ht.Option(ht.Qubit)]), - type_args=[ht.TypeTypeArg(ht.Qubit)], ) - map_op = array_map(ht.Qubit, qubit_num_args[i], ht.Option(ht.Qubit)) - new_c = dfg.builder.add_op(map_op, result_array, wrap) c = control.ctrl[0] assert isinstance(c, PlaceNode) - - dfg[c.place] = new_c + dfg[c.place] = control_array else: - val_from_std = array_convert_from_std_array(ht.Qubit, qubit_num_args[i]) - std_arr = dfg.builder.add_op(val_from_std, outport) - unpacked = unpack_array(dfg.builder, std_arr) - for c, wire in zip(control.ctrl, unpacked, strict=False): + control_array = dfg.builder.add_op( + array_convert_from_std_array(ht.Qubit, qubit_num_args[i]), outport + ) + unpacked = unpack_array(dfg.builder, control_array) + for c, new_c in zip(control.ctrl, unpacked, strict=False): assert isinstance(c, PlaceNode) - dfg[c.place] = wire + dfg[c.place] = new_c for arg in captured: if InputFlags.Inout in arg.flags: diff --git a/guppylang-internals/src/guppylang_internals/nodes.py b/guppylang-internals/src/guppylang_internals/nodes.py index e80e4a1fd..e5f1dc2fc 100644 --- a/guppylang-internals/src/guppylang_internals/nodes.py +++ b/guppylang-internals/src/guppylang_internals/nodes.py @@ -538,11 +538,11 @@ def __str__(self) -> str: # generate a function name from the def_id return f"__WithBlock__({self.def_id})" - def is_dagger(self) -> bool: + def has_dagger(self) -> bool: return len(self.dagger) % 2 == 1 def has_control(self) -> bool: return any(len(c.ctrl) > 0 for c in self.control) - def is_power(self) -> bool: + def has_power(self) -> bool: return len(self.power) > 0 diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py index f6c2db294..fb475860b 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/compiler/tket_exts.py @@ -8,7 +8,6 @@ global_phase, guppy, modifier, - opaque_bool, qsystem, qsystem_random, qsystem_utils, diff --git a/pyproject.toml b/pyproject.toml index e5703256d..ce3752060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ miette-py = { workspace = true } # Uncomment these to test the latest dependency version during development # hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "50a2bac" } -tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "73ff49b" } +tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "f0bc211" } [build-system] requires = ["hatchling"] From 466c3254edd8394bdc4cd8dae714d98373d693b6 Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Wed, 8 Oct 2025 12:14:00 +0100 Subject: [PATCH 03/13] UnitaryFlags -- flags to annotate unitarity some refactoring --- .../src/guppylang_internals/cfg/builder.py | 16 ++- .../src/guppylang_internals/cfg/cfg.py | 3 + .../checker/cfg_checker.py | 5 + .../checker/func_checker.py | 17 ++- .../checker/modifier_checker.py | 10 +- .../checker/unitary_checker.py | 114 ++++++++++++++++++ .../compiler/modifier_compiler.py | 2 + .../src/guppylang_internals/decorator.py | 15 ++- .../guppylang_internals/definition/custom.py | 4 + .../definition/declaration.py | 8 +- .../definition/function.py | 17 ++- .../definition/pytket_circuits.py | 1 + .../src/guppylang_internals/tys/errors.py | 24 +++- .../src/guppylang_internals/tys/qubit.py | 39 +++++- .../src/guppylang_internals/tys/ty.py | 19 +++ guppylang/src/guppylang/decorator.py | 56 +++++++-- 16 files changed, 328 insertions(+), 22 deletions(-) create mode 100644 guppylang-internals/src/guppylang_internals/checker/unitary_checker.py diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index 4711f19b6..e7d8aa996 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -46,7 +46,7 @@ Power, ) from guppylang_internals.span import Span, to_span -from guppylang_internals.tys.ty import NoneType +from guppylang_internals.tys.ty import NoneType, UnitaryFlags # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -78,7 +78,13 @@ class CFGBuilder(AstVisitor[BB | None]): cfg: CFG globals: Globals - def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG: + def build( + self, + nodes: list[ast.stmt], + returns_none: bool, + globals: Globals, + uniraty_flags: UnitaryFlags = UnitaryFlags.NoFlags, + ) -> CFG: """Builds a CFG from a list of ast nodes. We also require the expected number of return ports for the whole CFG. This is @@ -86,6 +92,7 @@ def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> variables. """ self.cfg = CFG() + self.unitary_flags = uniraty_flags self.globals = globals final_bb = self.visit_stmts( @@ -273,6 +280,7 @@ def visit_FunctionDef( func_ty = check_signature(node, self.globals) returns_none = isinstance(func_ty.output, NoneType) + # No UnitaryFlags are assigned to nested functions cfg = CFGBuilder().build(node.body, returns_none, self.globals) new_node = NestedFunctionDef( @@ -300,6 +308,10 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None: modifier = self._handle_withitem(item) new_node.push_modifier(modifier) + unitary_flags = UnitaryFlags.NoFlags + # TODO (k.hirata): set unitary_flags properly + object.__setattr__(cfg, "unitary_flags", unitary_flags) + set_location_from(new_node, node) bb.statements.append(new_node) return bb diff --git a/guppylang-internals/src/guppylang_internals/cfg/cfg.py b/guppylang-internals/src/guppylang_internals/cfg/cfg.py index c7e4c8cb3..22b607cd1 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/cfg.py +++ b/guppylang-internals/src/guppylang_internals/cfg/cfg.py @@ -12,6 +12,7 @@ ) from guppylang_internals.cfg.bb import BB, BBStatement, VariableStats from guppylang_internals.nodes import InoutReturnSentinel +from guppylang_internals.tys.ty import UnitaryFlags T = TypeVar("T", bound=BB) @@ -29,6 +30,7 @@ class BaseCFG(Generic[T]): #: Set of variables defined in this CFG assigned_somewhere: set[str] + unitary_flags: UnitaryFlags def __init__( self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None @@ -42,6 +44,7 @@ def __init__( self.ass_before = {} self.maybe_ass_before = {} self.assigned_somewhere = set() + self.unitary_flags = UnitaryFlags.NoFlags def ancestors(self, *bbs: T) -> Iterator[T]: """Returns an iterator over all ancestors of the given BBs in BFS order.""" diff --git a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py index 92594a51a..5edee5412 100644 --- a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py @@ -154,6 +154,11 @@ def check_cfg( from guppylang_internals.checker.linearity_checker import check_cfg_linearity linearity_checked_cfg = check_cfg_linearity(checked_cfg, func_name, globals) + + from guppylang_internals.checker.unitary_checker import check_cfg_unitary + + check_cfg_unitary(linearity_checked_cfg, cfg.unitary_flags) + return linearity_checked_cfg diff --git a/guppylang-internals/src/guppylang_internals/checker/func_checker.py b/guppylang-internals/src/guppylang_internals/checker/func_checker.py index 59d572761..ea848f678 100644 --- a/guppylang-internals/src/guppylang_internals/checker/func_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/func_checker.py @@ -37,6 +37,7 @@ InputFlags, NoneType, Type, + UnitaryFlags, unify, ) @@ -136,7 +137,7 @@ def check_global_func_def( returns_none = isinstance(ty.output, NoneType) assert ty.input_names is not None - cfg = CFGBuilder().build(func_def.body, returns_none, globals) + cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags) inputs = [ Variable(x, inp.ty, loc, inp.flags, is_func_input=True) for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True) @@ -150,9 +151,13 @@ def check_global_func_def( def check_nested_func_def( - func_def: NestedFunctionDef, bb: BB, ctx: Context + func_def: NestedFunctionDef, + bb: BB, + ctx: Context, + # unitary_flags: (k.hirata) ) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" + # unitary_flags: (k.hirata) func_ty = check_signature(func_def, ctx.globals) assert func_ty.input_names is not None @@ -213,6 +218,7 @@ def check_nested_func_def( from guppylang.defs import GuppyDefinition from guppylang_internals.definition.function import ParsedFunctionDef + # TODO (k.hirata): unitary_flags func = ParsedFunctionDef(def_id, func_def.name, func_def, func_ty, None) DEF_STORE.register_def(func, None) ENGINE.parsed[def_id] = func @@ -238,7 +244,10 @@ def check_nested_func_def( def check_signature( - func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None + func_def: ast.FunctionDef, + globals: Globals, + def_id: DefId | None = None, + unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, ) -> FunctionType: """Checks the signature of a function definition and returns the corresponding Guppy type. @@ -247,6 +256,7 @@ def check_signature( passed. This will be used to check or infer the type annotation for the `self` argument. """ + # TODO:(k.hirata) unitary_flags if len(func_def.args.posonlyargs) != 0: raise GuppyError( UnsupportedError(func_def.args.posonlyargs[0], "Positional-only parameters") @@ -307,6 +317,7 @@ def check_signature( output, input_names, sorted(param_var_mapping.values(), key=lambda v: v.idx), + unitary_flags=unitary_flags, ) diff --git a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py index c4a3c9271..e5033e9e9 100644 --- a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py @@ -16,6 +16,7 @@ InputFlags, NoneType, Type, + UnitaryFlags, ) @@ -71,7 +72,7 @@ def check_modified_block( # This name could be printed in error messages, for example, # when the linearity checker fails in the modifier body checked_cfg = check_cfg(cfg, inputs, NoneType(), {}, "__modified__()", globals) - func_ty = check_modified_block_signature(checked_cfg.input_tys) + func_ty = check_modified_block_signature(modified_block, checked_cfg.input_tys) checked_modifier = CheckedModifiedBlock( def_id, @@ -94,9 +95,13 @@ def _set_inout_if_non_copyable(var: Variable) -> Variable: return var -def check_modified_block_signature(input_tys: list[Type]) -> FunctionType: +def check_modified_block_signature( + modified_block: ModifiedBlock, input_tys: list[Type] +) -> FunctionType: """Check and create the signature of a function definition for a body of a `With` block.""" + # TODO (k.hirata): set unitary flags + unitary_flags = UnitaryFlags.NoFlags func_ty = FunctionType( [ @@ -104,6 +109,7 @@ def check_modified_block_signature(input_tys: list[Type]) -> FunctionType: for t in input_tys ], NoneType(), + unitary_flags=unitary_flags, ) return func_ty diff --git a/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py b/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py new file mode 100644 index 000000000..8dc8b7ca1 --- /dev/null +++ b/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py @@ -0,0 +1,114 @@ +import ast +from typing import Any + +from guppylang_internals.ast_util import get_type +from guppylang_internals.checker.cfg_checker import CheckedBB, CheckedCFG +from guppylang_internals.checker.core import Place, contains_subscript +from guppylang_internals.checker.errors.generic import ( + InvalidUnderDagger, + UnsupportedError, +) +from guppylang_internals.definition.value import CallableDef +from guppylang_internals.engine import ENGINE +from guppylang_internals.error import GuppyError, GuppyTypeError +from guppylang_internals.nodes import ( + AnyCall, + BarrierExpr, + GlobalCall, + LocalCall, + PlaceNode, + ResultExpr, + StateResultExpr, + TensorCall, +) +from guppylang_internals.tys.errors import UnitaryCallError +from guppylang_internals.tys.qubit import contain_qubit_ty +from guppylang_internals.tys.ty import FunctionType, UnitaryFlags + + +class BBUnitaryChecker(ast.NodeVisitor): + flags: UnitaryFlags + + """AST visitor that checks whether the modifiers (dagger, control, power) + are applicable.""" + + def check(self, bb: CheckedBB[Place], unitary_flags: UnitaryFlags) -> None: + self.flags = unitary_flags + for stmt in bb.statements: + self.visit(stmt) + + def _check_classical_args(self, args: list[ast.expr]) -> bool: + for arg in args: + self.visit(arg) + if contain_qubit_ty(get_type(arg)): + return False + return True + + def _check_call(self, node: AnyCall, ty: FunctionType) -> None: + classic = self._check_classical_args(node.args) + flag_ok = self.flags in ty.unitary_flags + if not classic and not flag_ok: + raise GuppyTypeError( + UnitaryCallError(node, self.flags & (~ty.unitary_flags)) + ) + + def visit_GlobalCall(self, node: GlobalCall) -> None: + func = ENGINE.get_parsed(node.def_id) + assert isinstance(func, CallableDef) + self._check_call(node, func.ty) + + def visit_LocalCall(self, node: LocalCall) -> None: + func = get_type(node.func) + assert isinstance(func, FunctionType) + self._check_call(node, func) + + def visit_TensorCall(self, node: TensorCall) -> None: + self._check_call(node, node.tensor_ty) + + def visit_BarrierExpr(self, node: BarrierExpr) -> None: + # Barrier is always allowed + pass + + def visit_ResultExpr(self, node: ResultExpr) -> None: + # Result is always allowed + pass + + def visit_StateResultExpr(self, node: StateResultExpr) -> None: + # StateResult is always allowed + pass + + def _check_assign(self, node: Any) -> None: + assert isinstance(node, ast.Assign | ast.AnnAssign | ast.AugAssign) + if UnitaryFlags.Dagger in self.flags: + raise GuppyError(InvalidUnderDagger(node, "Assignment")) + if node.value is not None: + self.visit(node.value) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + self._check_assign(node) + + def visit_Assign(self, node: ast.Assign) -> None: + self._check_assign(node) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + self._check_assign(node) + + def visit_For(self, node: ast.For) -> None: + if UnitaryFlags.Dagger in self.flags: + raise GuppyError(InvalidUnderDagger(node, "Loop")) + self.generic_visit(node) + + def visit_PlaceNode(self, node: PlaceNode) -> None: + if UnitaryFlags.Dagger in self.flags and contains_subscript(node.place): + raise GuppyError( + UnsupportedError(node, "index access", True, "dagger context") + ) + + +def check_cfg_unitary( + cfg: CheckedCFG[Place], + unitary_flags: UnitaryFlags, +) -> None: + bb_checker = BBUnitaryChecker() + for bb in cfg.bbs: + bb_checker.check(bb, unitary_flags) diff --git a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py index 01827477c..a53223ee7 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py +++ b/guppylang-internals/src/guppylang_internals/compiler/modifier_compiler.py @@ -12,6 +12,7 @@ array_unwrap_elem, array_wrap_elem, ) +from guppylang_internals.definition.function import add_unitarity_metadata from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode from guppylang_internals.std._internal.compiler.array import ( array_convert_from_std_array, @@ -61,6 +62,7 @@ def compile_modified_block( func_builder = dfg.builder.module_root_builder().define_function( str(modified_block), hugr_ty.input, hugr_ty.output ) + add_unitarity_metadata(func_builder.parent_node, modified_block.ty.unitary_flags) # compile body cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx) diff --git a/guppylang-internals/src/guppylang_internals/decorator.py b/guppylang-internals/src/guppylang_internals/decorator.py index 59407a5e1..3de18109c 100644 --- a/guppylang-internals/src/guppylang_internals/decorator.py +++ b/guppylang-internals/src/guppylang_internals/decorator.py @@ -39,6 +39,7 @@ InputFlags, NoneType, NumericType, + UnitaryFlags, ) if TYPE_CHECKING: @@ -75,6 +76,7 @@ def custom_function( higher_order_value: bool = True, name: str = "", signature: FunctionType | None = None, + unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: """Decorator to add custom typing or compilation behaviour to function decls. @@ -86,6 +88,8 @@ def custom_function( def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: call_checker = checker or DefaultCallChecker() + if signature is not None: + object.__setattr__(signature, "unitary_flags", unitary_flags) func = RawCustomFunctionDef( DefId.fresh(), name or f.__name__, @@ -95,6 +99,7 @@ def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: compiler or NotImplementedCallCompiler(), higher_order_value, signature, + unitary_flags, ) DEF_STORE.register_def(func, get_calling_frame()) return GuppyFunctionDefinition(func) @@ -108,6 +113,7 @@ def hugr_op( higher_order_value: bool = True, name: str = "", signature: FunctionType | None = None, + unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: """Decorator to annotate function declarations as HUGR ops. @@ -119,7 +125,14 @@ def hugr_op( value. name: The name of the function. """ - return custom_function(OpCompiler(op), checker, higher_order_value, name, signature) + return custom_function( + OpCompiler(op), + checker, + higher_order_value, + name, + signature, + unitary_flags=unitary_flags, + ) def extend_type(defn: TypeDef, return_class: bool = False) -> Callable[[type], type]: diff --git a/guppylang-internals/src/guppylang_internals/definition/custom.py b/guppylang-internals/src/guppylang_internals/definition/custom.py index e96bb3e06..8932bb80b 100644 --- a/guppylang-internals/src/guppylang_internals/definition/custom.py +++ b/guppylang-internals/src/guppylang_internals/definition/custom.py @@ -42,6 +42,7 @@ InputFlags, NoneType, Type, + UnitaryFlags, type_to_row, ) @@ -112,6 +113,8 @@ class RawCustomFunctionDef(ParsableDef): signature: FunctionType | None + unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags) + description: str = field(default="function", init=False) def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef": @@ -134,6 +137,7 @@ def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef": raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name)) sig = self.signature or self._get_signature(func_ast, globals) ty = sig or FunctionType([], NoneType()) + object.__setattr__(ty, "unitary_flags", self.unitary_flags) return CustomFunctionDef( self.id, self.name, diff --git a/guppylang-internals/src/guppylang_internals/definition/declaration.py b/guppylang-internals/src/guppylang_internals/definition/declaration.py index c5620661b..718a8ccae 100644 --- a/guppylang-internals/src/guppylang_internals/definition/declaration.py +++ b/guppylang-internals/src/guppylang_internals/definition/declaration.py @@ -34,7 +34,7 @@ from guppylang_internals.span import SourceMap from guppylang_internals.tys.param import Parameter from guppylang_internals.tys.subst import Inst, Subst -from guppylang_internals.tys.ty import Type +from guppylang_internals.tys.ty import Type, UnitaryFlags @dataclass(frozen=True) @@ -65,10 +65,14 @@ class RawFunctionDecl(ParsableDef): python_func: PyFunc description: str = field(default="function", init=False) + unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True) + def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func, sources) - ty = check_signature(func_ast, globals, self.id) + ty = check_signature( + func_ast, globals, self.id, unitary_flags=self.unitary_flags + ) if not has_empty_body(func_ast): raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name)) # Make sure we won't need monomorphization to compile this declaration diff --git a/guppylang-internals/src/guppylang_internals/definition/function.py b/guppylang-internals/src/guppylang_internals/definition/function.py index cddbc41f9..127f7b435 100644 --- a/guppylang-internals/src/guppylang_internals/definition/function.py +++ b/guppylang-internals/src/guppylang_internals/definition/function.py @@ -43,7 +43,7 @@ from guppylang_internals.nodes import GlobalCall from guppylang_internals.span import SourceMap from guppylang_internals.tys.subst import Inst, Subst -from guppylang_internals.tys.ty import FunctionType, Type, type_to_row +from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row if TYPE_CHECKING: from guppylang_internals.tys.param import Parameter @@ -70,6 +70,8 @@ class RawFunctionDef(ParsableDef): description: str = field(default="function", init=False) + unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True) + def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func, sources) @@ -173,6 +175,7 @@ def monomorphize( func_def = module.module_root_builder().define_function( self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params ) + add_unitarity_metadata(func_def.parent_node, mono_ty.unitary_flags) return CompiledFunctionDef( self.id, self.name, @@ -300,3 +303,15 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS else: node = ast.parse(source).body[0] return source, node, line_offset + + +def add_unitarity_metadata(node: Node, flags: UnitaryFlags) -> None: + code = 0 + if flags & UnitaryFlags.Dagger: + code |= 1 + if flags & UnitaryFlags.Control: + code |= 2 + if flags & UnitaryFlags.Power: + code |= 4 + + node.metadata["unitary"] = code diff --git a/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py b/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py index e6fbf134c..f2c8e7a10 100644 --- a/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py +++ b/guppylang-internals/src/guppylang_internals/definition/pytket_circuits.py @@ -398,6 +398,7 @@ def _signature_from_circuit( use_arrays: bool = False, ) -> FunctionType: """Helper function for inferring a function signature from a pytket circuit.""" + # May want to set proper unitary flags in the future. try: import pytket diff --git a/guppylang-internals/src/guppylang_internals/tys/errors.py b/guppylang-internals/src/guppylang_internals/tys/errors.py index d0918dfcc..94fc1989c 100644 --- a/guppylang-internals/src/guppylang_internals/tys/errors.py +++ b/guppylang-internals/src/guppylang_internals/tys/errors.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from guppylang_internals.definition.parameter import ParamDef - from guppylang_internals.tys.ty import Type + from guppylang_internals.tys.ty import Type, UnitaryFlags @dataclass(frozen=True) @@ -182,3 +182,25 @@ class InvalidFlagError(Error): class FlagNotAllowedError(Error): title: ClassVar[str] = "Invalid annotation" span_label: ClassVar[str] = "`@` type annotations are not allowed in this position" + + +@dataclass(frozen=True) +class UnitaryCallError(Error): + title: ClassVar[str] = "Unitary constraint violation" + span_label: ClassVar[str] = ( + "This function cannot be called in a {render_flags} context" + ) + flags: "UnitaryFlags" + + @property + def render_flags(self) -> str: + from guppylang_internals.tys.ty import UnitaryFlags + + if self.flags == UnitaryFlags.Dagger: + return "dagger" + elif self.flags == UnitaryFlags.Control: + return "control" + elif self.flags == UnitaryFlags.Power: + return "power" + else: + return "unitary" diff --git a/guppylang-internals/src/guppylang_internals/tys/qubit.py b/guppylang-internals/src/guppylang_internals/tys/qubit.py index 036f832ec..49a15144e 100644 --- a/guppylang-internals/src/guppylang_internals/tys/qubit.py +++ b/guppylang-internals/src/guppylang_internals/tys/qubit.py @@ -1,8 +1,10 @@ import functools -from typing import cast +from typing import Any, cast from guppylang_internals.definition.ty import TypeDef -from guppylang_internals.tys.ty import Type +from guppylang_internals.tys.arg import TypeArg +from guppylang_internals.tys.common import Visitor +from guppylang_internals.tys.ty import OpaqueType, Type @functools.cache @@ -25,3 +27,36 @@ def is_qubit_ty(ty: Type) -> bool: before qubit types are registered. """ return ty == qubit_ty() + + +class QubitFinder(Visitor): + """Type visitor that checks if a type contains the qubit type.""" + + class FoundFlag(Exception): + pass + + @functools.singledispatchmethod + def visit(self, ty: Any) -> bool: # type: ignore[override] + return False + + @visit.register + def _visit_OpaqueType(self, ty: OpaqueType) -> bool: + if is_qubit_ty(ty): + raise self.FoundFlag + return False + + @visit.register + def _visit_TypeArg(self, arg: TypeArg) -> bool: + arg.ty.visit(self) + return True + + +def contain_qubit_ty(ty: Type) -> bool: + """Checks if the given type contains the qubit type.""" + finder = QubitFinder() + try: + ty.visit(finder) + except QubitFinder.FoundFlag: + return True + else: + return False diff --git a/guppylang-internals/src/guppylang_internals/tys/ty.py b/guppylang-internals/src/guppylang_internals/tys/ty.py index 0ea109be0..b7b06fcdf 100644 --- a/guppylang-internals/src/guppylang_internals/tys/ty.py +++ b/guppylang-internals/src/guppylang_internals/tys/ty.py @@ -382,6 +382,21 @@ class InputFlags(Flag): Comptime = auto() +class UnitaryFlags(Flag): + """Flags that can be set on functions to indicate their unitary properties. + + The flags indicate under which conditions a function can be used + in a unitary context. + """ + + NoFlags = 0 + Control = auto() + Dagger = auto() + Power = auto() + + Unitary = Control | Dagger | Power + + @dataclass(frozen=True) class FuncInput: """A single input of a function type.""" @@ -407,6 +422,8 @@ class FunctionType(ParametrizedTypeBase): intrinsically_droppable: bool = field(default=True, init=True) hugr_bound: ht.TypeBound = field(default=ht.TypeBound.Copyable, init=False) + unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, init=True) + def __init__( self, inputs: Sequence[FuncInput], @@ -414,6 +431,7 @@ def __init__( input_names: Sequence[str] | None = None, params: Sequence[Parameter] | None = None, comptime_args: Sequence[ConstArg] | None = None, + unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, ) -> None: # We need a custom __init__ to set the args args: list[Argument] = [TypeArg(inp.ty) for inp in inputs] @@ -435,6 +453,7 @@ def __init__( object.__setattr__(self, "output", output) object.__setattr__(self, "input_names", input_names or []) object.__setattr__(self, "params", params) + object.__setattr__(self, "unitary_flags", unitary_flags) @property def parametrized(self) -> bool: diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index c41b63040..bc19063d7 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -49,6 +49,7 @@ FunctionType, NoneType, NumericType, + UnitaryFlags, ) from hugr import ops from hugr import tys as ht @@ -83,10 +84,33 @@ class _Guppy: """Class for the `@guppy` decorator.""" - def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: - defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f) - DEF_STORE.register_def(defn, get_calling_frame()) - return GuppyFunctionDefinition(defn) + # TODO (k.hirata): + # + # def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: + # defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f) + # DEF_STORE.register_def(defn, get_calling_frame()) + # return GuppyFunctionDefinition(defn) + # + # trying to support both `@guppy` and `@guppy(unitary_flags=...)` styles + def __call__( + self, + f: Callable[P, T] | None = None, + unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, + ) -> ( + GuppyFunctionDefinition[P, T] + | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]] + ): + def register(fn: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: + defn = RawFunctionDef( + DefId.fresh(), fn.__name__, None, fn, unitary_flags=unitary_flags + ) + DEF_STORE.register_def(defn, get_calling_frame()) + return GuppyFunctionDefinition(defn) + + if f is None: + return register + else: + return register(f) def comptime(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: """Registers a function to be executed at compile-time during Guppy compilation, @@ -221,11 +245,27 @@ def hugr_op( ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: return hugr_op(op, checker, higher_order_value, name, signature) - def declare(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: + def declare( + self, + f: Callable[P, T] | None = None, + unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, + ) -> ( + GuppyFunctionDefinition[P, T] + | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]] + ): """Declares a Guppy function without defining it.""" - defn = RawFunctionDecl(DefId.fresh(), f.__name__, None, f) - DEF_STORE.register_def(defn, get_calling_frame()) - return GuppyFunctionDefinition(defn) + + def register(fn: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: + defn = RawFunctionDecl( + DefId.fresh(), fn.__name__, None, fn, unitary_flags=unitary_flags + ) + DEF_STORE.register_def(defn, get_calling_frame()) + return GuppyFunctionDefinition(defn) + + if f is None: + return register + else: + return register(f) def overload( self, *funcs: Any From cb411f826a64999371791436f7ddfb9120e612cf Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Wed, 8 Oct 2025 14:21:12 +0100 Subject: [PATCH 04/13] add flags to tests --- .../src/guppylang_internals/cfg/builder.py | 6 +-- .../definition/function.py | 2 +- .../src/guppylang_internals/nodes.py | 17 ++++++++- guppylang/src/guppylang/decorator.py | 7 ---- .../src/guppylang/std/quantum/__init__.py | 37 ++++++++++--------- .../modifier_errors/captured_var_inout_own.py | 4 +- .../captured_var_inout_reassign.py | 4 +- tests/error/modifier_errors/ctrl_arg_copy.py | 4 +- tests/error/modifier_errors/flag_call.py | 15 ++++++++ .../modifier_errors/flag_dagger_assign.py | 10 +++++ tests/error/modifier_errors/flag_loop.py | 11 ++++++ tests/error/modifier_errors/higher_order.err | 8 ++++ tests/error/modifier_errors/higher_order.py | 22 +++++++++++ 13 files changed, 111 insertions(+), 36 deletions(-) create mode 100644 tests/error/modifier_errors/flag_call.py create mode 100644 tests/error/modifier_errors/flag_dagger_assign.py create mode 100644 tests/error/modifier_errors/flag_loop.py create mode 100644 tests/error/modifier_errors/higher_order.err create mode 100644 tests/error/modifier_errors/higher_order.py diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index e7d8aa996..8423abcc7 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -92,7 +92,7 @@ def build( variables. """ self.cfg = CFG() - self.unitary_flags = uniraty_flags + self.cfg.unitary_flags = uniraty_flags self.globals = globals final_bb = self.visit_stmts( @@ -308,8 +308,8 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None: modifier = self._handle_withitem(item) new_node.push_modifier(modifier) - unitary_flags = UnitaryFlags.NoFlags - # TODO (k.hirata): set unitary_flags properly + # TODO: its parent's flags need to be added too + unitary_flags = new_node.add_flags(UnitaryFlags.NoFlags) object.__setattr__(cfg, "unitary_flags", unitary_flags) set_location_from(new_node, node) diff --git a/guppylang-internals/src/guppylang_internals/definition/function.py b/guppylang-internals/src/guppylang_internals/definition/function.py index 127f7b435..d52e04ae0 100644 --- a/guppylang-internals/src/guppylang_internals/definition/function.py +++ b/guppylang-internals/src/guppylang_internals/definition/function.py @@ -175,7 +175,7 @@ def monomorphize( func_def = module.module_root_builder().define_function( self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params ) - add_unitarity_metadata(func_def.parent_node, mono_ty.unitary_flags) + add_unitarity_metadata(func_def.parent_node, self.ty.unitary_flags) return CompiledFunctionDef( self.id, self.name, diff --git a/guppylang-internals/src/guppylang_internals/nodes.py b/guppylang-internals/src/guppylang_internals/nodes.py index e5f1dc2fc..96d543140 100644 --- a/guppylang-internals/src/guppylang_internals/nodes.py +++ b/guppylang-internals/src/guppylang_internals/nodes.py @@ -9,7 +9,13 @@ from guppylang_internals.span import Span, to_span from guppylang_internals.tys.const import Const from guppylang_internals.tys.subst import Inst -from guppylang_internals.tys.ty import FunctionType, StructType, TupleType, Type +from guppylang_internals.tys.ty import ( + FunctionType, + StructType, + TupleType, + Type, + UnitaryFlags, +) if TYPE_CHECKING: from guppylang_internals.cfg.cfg import CFG @@ -500,6 +506,15 @@ def push_modifier(self, modifier: Modifier) -> None: else: raise TypeError(f"Unknown modifier: {modifier}") + def add_flags(self, flags: UnitaryFlags) -> UnitaryFlags: + if self.is_dagger(): + flags |= UnitaryFlags.Dagger + if self.is_control(): + flags |= UnitaryFlags.Control + if self.is_power(): + flags |= UnitaryFlags.Power + return flags + class CheckedModifiedBlock(ast.With): def_id: "DefId" diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index bc19063d7..447c5bdf5 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -84,13 +84,6 @@ class _Guppy: """Class for the `@guppy` decorator.""" - # TODO (k.hirata): - # - # def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: - # defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f) - # DEF_STORE.register_def(defn, get_calling_frame()) - # return GuppyFunctionDefinition(defn) - # # trying to support both `@guppy` and `@guppy(unitary_flags=...)` styles def __call__( self, diff --git a/guppylang/src/guppylang/std/quantum/__init__.py b/guppylang/src/guppylang/std/quantum/__init__.py index ed2818ecd..33018e685 100644 --- a/guppylang/src/guppylang/std/quantum/__init__.py +++ b/guppylang/src/guppylang/std/quantum/__init__.py @@ -10,6 +10,7 @@ RotationCompiler, ) from guppylang_internals.std._internal.util import quantum_op +from guppylang_internals.tys.ty import UnitaryFlags from hugr import tys as ht from guppylang import guppy @@ -48,7 +49,7 @@ def maybe_qubit() -> Option[qubit]: if allocation succeeds or `nothing` if it fails.""" -@hugr_op(quantum_op("H")) +@hugr_op(quantum_op("H"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def h(q: qubit) -> None: r"""Hadamard gate command @@ -62,7 +63,7 @@ def h(q: qubit) -> None: """ -@hugr_op(quantum_op("CZ")) +@hugr_op(quantum_op("CZ"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def cz(control: qubit, target: qubit) -> None: r"""Controlled-Z gate command. @@ -82,7 +83,7 @@ def cz(control: qubit, target: qubit) -> None: """ -@hugr_op(quantum_op("CY")) +@hugr_op(quantum_op("CY"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def cy(control: qubit, target: qubit) -> None: r"""Controlled-Y gate command. @@ -102,7 +103,7 @@ def cy(control: qubit, target: qubit) -> None: """ -@hugr_op(quantum_op("CX")) +@hugr_op(quantum_op("CX"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def cx(control: qubit, target: qubit) -> None: r"""Controlled-X gate command. @@ -122,7 +123,7 @@ def cx(control: qubit, target: qubit) -> None: """ -@hugr_op(quantum_op("T")) +@hugr_op(quantum_op("T"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def t(q: qubit) -> None: r"""T gate. @@ -137,7 +138,7 @@ def t(q: qubit) -> None: """ -@hugr_op(quantum_op("S")) +@hugr_op(quantum_op("S"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def s(q: qubit) -> None: r"""S gate. @@ -152,7 +153,7 @@ def s(q: qubit) -> None: """ -@hugr_op(quantum_op("V")) +@hugr_op(quantum_op("V"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def v(q: qubit) -> None: r"""V gate. @@ -167,7 +168,7 @@ def v(q: qubit) -> None: """ -@hugr_op(quantum_op("X")) +@hugr_op(quantum_op("X"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def x(q: qubit) -> None: r"""X gate. @@ -182,7 +183,7 @@ def x(q: qubit) -> None: """ -@hugr_op(quantum_op("Y")) +@hugr_op(quantum_op("Y"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def y(q: qubit) -> None: r"""Y gate. @@ -197,7 +198,7 @@ def y(q: qubit) -> None: """ -@hugr_op(quantum_op("Z")) +@hugr_op(quantum_op("Z"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def z(q: qubit) -> None: r"""Z gate. @@ -212,7 +213,7 @@ def z(q: qubit) -> None: """ -@hugr_op(quantum_op("Tdg")) +@hugr_op(quantum_op("Tdg"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def tdg(q: qubit) -> None: r"""Tdg gate. @@ -227,7 +228,7 @@ def tdg(q: qubit) -> None: """ -@hugr_op(quantum_op("Sdg")) +@hugr_op(quantum_op("Sdg"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def sdg(q: qubit) -> None: r"""Sdg gate. @@ -242,7 +243,7 @@ def sdg(q: qubit) -> None: """ -@hugr_op(quantum_op("Vdg")) +@hugr_op(quantum_op("Vdg"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def vdg(q: qubit) -> None: r"""Vdg gate. @@ -257,7 +258,7 @@ def vdg(q: qubit) -> None: """ -@custom_function(RotationCompiler("Rz")) +@custom_function(RotationCompiler("Rz"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def rz(q: qubit, angle: angle) -> None: r"""Rz gate. @@ -273,7 +274,7 @@ def rz(q: qubit, angle: angle) -> None: """ -@custom_function(RotationCompiler("Rx")) +@custom_function(RotationCompiler("Rx"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def rx(q: qubit, angle: angle) -> None: r"""Rx gate. @@ -288,7 +289,7 @@ def rx(q: qubit, angle: angle) -> None: """ -@custom_function(RotationCompiler("Ry")) +@custom_function(RotationCompiler("Ry"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def ry(q: qubit, angle: angle) -> None: r"""Ry gate. @@ -303,7 +304,7 @@ def ry(q: qubit, angle: angle) -> None: """ -@custom_function(RotationCompiler("CRz")) +@custom_function(RotationCompiler("CRz"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def crz(control: qubit, target: qubit, angle: angle) -> None: r"""Controlled-Rz gate command. @@ -323,7 +324,7 @@ def crz(control: qubit, target: qubit, angle: angle) -> None: """ -@hugr_op(quantum_op("Toffoli")) +@hugr_op(quantum_op("Toffoli"), unitary_flags=UnitaryFlags.Unitary) @no_type_check def toffoli(control1: qubit, control2: qubit, target: qubit) -> None: r"""A Toffoli gate command. Also sometimes known as a CCX gate. diff --git a/tests/error/modifier_errors/captured_var_inout_own.py b/tests/error/modifier_errors/captured_var_inout_own.py index 0a06b9170..b136b86e8 100644 --- a/tests/error/modifier_errors/captured_var_inout_own.py +++ b/tests/error/modifier_errors/captured_var_inout_own.py @@ -1,8 +1,8 @@ from guppylang.decorator import guppy -from guppylang.std.quantum import qubit, owned +from guppylang.std.quantum import qubit, owned, UnitaryFlags -@guppy.declare +@guppy.declare(unitary_flags=UnitaryFlags.Dagger) def discard(q: qubit @ owned) -> None: ... diff --git a/tests/error/modifier_errors/captured_var_inout_reassign.py b/tests/error/modifier_errors/captured_var_inout_reassign.py index 81b02941f..529ce22a4 100644 --- a/tests/error/modifier_errors/captured_var_inout_reassign.py +++ b/tests/error/modifier_errors/captured_var_inout_reassign.py @@ -1,8 +1,8 @@ from guppylang.decorator import guppy -from guppylang.std.quantum import qubit +from guppylang.std.quantum import qubit, UnitaryFlags -@guppy.declare +@guppy.declare(unitary_flags=UnitaryFlags.Dagger) def use(q: qubit) -> None: ... diff --git a/tests/error/modifier_errors/ctrl_arg_copy.py b/tests/error/modifier_errors/ctrl_arg_copy.py index 1d94d6bf2..0ea99862a 100644 --- a/tests/error/modifier_errors/ctrl_arg_copy.py +++ b/tests/error/modifier_errors/ctrl_arg_copy.py @@ -1,12 +1,12 @@ from guppylang.decorator import guppy -from guppylang.std.quantum import qubit, owned +from guppylang.std.quantum import qubit, owned, UnitaryFlags @guppy.declare def discard(q: qubit @ owned) -> None: ... -@guppy.declare +@guppy.declare(unitary_flags=UnitaryFlags.Control) def use(q: qubit) -> None: ... diff --git a/tests/error/modifier_errors/flag_call.py b/tests/error/modifier_errors/flag_call.py new file mode 100644 index 000000000..3cee648d1 --- /dev/null +++ b/tests/error/modifier_errors/flag_call.py @@ -0,0 +1,15 @@ +from guppylang_internals.tys.ty import UnitaryFlags +from guppylang.decorator import guppy +from guppylang.std.quantum import qubit + + +@guppy.declare +def foo(x: qubit) -> None: ... + + +@guppy(unitary_flags=UnitaryFlags.Dagger) +def test(x: qubit) -> None: + foo(x) + + +test.compile() diff --git a/tests/error/modifier_errors/flag_dagger_assign.py b/tests/error/modifier_errors/flag_dagger_assign.py new file mode 100644 index 000000000..ebda41548 --- /dev/null +++ b/tests/error/modifier_errors/flag_dagger_assign.py @@ -0,0 +1,10 @@ +from guppylang.decorator import guppy +from guppylang.std.quantum import UnitaryFlags + + +@guppy(unitary_flags=UnitaryFlags.Dagger) +def test() -> None: + x = 3 + + +test.compile() diff --git a/tests/error/modifier_errors/flag_loop.py b/tests/error/modifier_errors/flag_loop.py new file mode 100644 index 000000000..17f48a9da --- /dev/null +++ b/tests/error/modifier_errors/flag_loop.py @@ -0,0 +1,11 @@ +from guppylang.decorator import guppy +from guppylang.std.quantum import UnitaryFlags + + +@guppy(unitary_flags=UnitaryFlags.Dagger) +def test() -> None: + for _ in range(3): + pass + + +test.compile() diff --git a/tests/error/modifier_errors/higher_order.err b/tests/error/modifier_errors/higher_order.err new file mode 100644 index 000000000..e8432761f --- /dev/null +++ b/tests/error/modifier_errors/higher_order.err @@ -0,0 +1,8 @@ +Error: Unitary constraint violation (at $FILE:18:8) + | +16 | q = qubit() +17 | with dagger: +18 | test_ho(h, q) + | ^^^^^^^^^^^^^ This function cannot be called in a dagger context + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/higher_order.py b/tests/error/modifier_errors/higher_order.py new file mode 100644 index 000000000..51c826324 --- /dev/null +++ b/tests/error/modifier_errors/higher_order.py @@ -0,0 +1,22 @@ +from guppylang_internals.tys.ty import UnitaryFlags +from guppylang.decorator import guppy +from guppylang.std.quantum import qubit, h, discard +from collections.abc import Callable + + +# The flag is required to be used in dagger context +@guppy(unitary_flags=UnitaryFlags.Dagger) +def test_ho(f: Callable[[qubit], None], q: qubit) -> None: + # There is no way to use specify flags for f + f(q) + + +@guppy +def test() -> None: + q = qubit() + with dagger: + test_ho(h, q) + discard(q) + + +test.compile() From fa13e8d0de00b0ad917130072be2335ba833f71f Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Wed, 8 Oct 2025 15:06:59 +0100 Subject: [PATCH 05/13] typo --- guppylang-internals/src/guppylang_internals/cfg/builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index 8423abcc7..eaa410a9d 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -83,7 +83,7 @@ def build( nodes: list[ast.stmt], returns_none: bool, globals: Globals, - uniraty_flags: UnitaryFlags = UnitaryFlags.NoFlags, + unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, ) -> CFG: """Builds a CFG from a list of ast nodes. @@ -92,7 +92,7 @@ def build( variables. """ self.cfg = CFG() - self.cfg.unitary_flags = uniraty_flags + self.cfg.unitary_flags = unitary_flags self.globals = globals final_bb = self.visit_stmts( From 21d2c33fdd2f0fba65dfdd94131853e171a84071 Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Wed, 8 Oct 2025 18:00:21 +0100 Subject: [PATCH 06/13] fixed some bugs --- .../src/guppylang_internals/cfg/builder.py | 2 +- .../checker/cfg_checker.py | 33 ++++++- .../checker/func_checker.py | 1 - .../checker/linearity_checker.py | 1 + .../checker/modifier_checker.py | 4 +- .../checker/unitary_checker.py | 5 - .../definition/function.py | 2 +- .../src/guppylang_internals/nodes.py | 3 +- guppylang/src/guppylang/decorator.py | 93 +++++++++++-------- .../captured_var_inout_own.err | 10 +- .../modifier_errors/captured_var_inout_own.py | 3 +- .../captured_var_inout_reassign.err | 8 +- .../captured_var_inout_reassign.py | 3 +- tests/error/modifier_errors/ctrl_arg_copy.err | 10 +- tests/error/modifier_errors/ctrl_arg_copy.py | 3 +- tests/error/modifier_errors/flag_call.err | 8 ++ tests/error/modifier_errors/flag_call.py | 8 +- .../modifier_errors/flag_dagger_assign.err | 8 ++ .../modifier_errors/flag_dagger_assign.py | 3 +- tests/error/modifier_errors/flag_loop.err | 8 ++ tests/error/modifier_errors/flag_loop.py | 6 +- tests/error/modifier_errors/higher_order.err | 10 +- tests/error/modifier_errors/higher_order.py | 3 +- 23 files changed, 151 insertions(+), 84 deletions(-) create mode 100644 tests/error/modifier_errors/flag_call.err create mode 100644 tests/error/modifier_errors/flag_dagger_assign.err create mode 100644 tests/error/modifier_errors/flag_loop.err diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index eaa410a9d..a8c10ae66 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -309,7 +309,7 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None: new_node.push_modifier(modifier) # TODO: its parent's flags need to be added too - unitary_flags = new_node.add_flags(UnitaryFlags.NoFlags) + unitary_flags = new_node.flags() object.__setattr__(cfg, "unitary_flags", unitary_flags) set_location_from(new_node, node) diff --git a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py index 5edee5412..7fb96fe0c 100644 --- a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from typing import ClassVar, Generic, TypeVar -from guppylang_internals.ast_util import line_col +from guppylang_internals.ast_util import line_col, loop_in_ast from guppylang_internals.cfg.bb import BB from guppylang_internals.cfg.cfg import CFG, BaseCFG from guppylang_internals.checker.core import ( @@ -21,12 +21,13 @@ V, Variable, ) +from guppylang_internals.checker.errors.generic import InvalidUnderDagger from guppylang_internals.checker.expr_checker import ExprSynthesizer, to_bool from guppylang_internals.checker.stmt_checker import StmtChecker from guppylang_internals.diagnostic import Error, Note from guppylang_internals.error import GuppyError from guppylang_internals.tys.param import Parameter -from guppylang_internals.tys.ty import InputFlags, Type +from guppylang_internals.tys.ty import InputFlags, Type, UnitaryFlags Row = Sequence[V] @@ -88,6 +89,8 @@ def check_cfg( inout_vars = [v for v in inputs if InputFlags.Inout in v.flags] cfg.analyze(ass_before, ass_before, [v.name for v in inout_vars]) + check_invalid_in_dagger(cfg) + # We start by compiling the entry BB checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty) checked_cfg.entry_bb = check_bb( @@ -149,6 +152,7 @@ def check_cfg( checked_cfg.maybe_ass_before = { compiled[bb]: cfg.maybe_ass_before[bb] for bb in required_bbs } + checked_cfg.unitary_flags = cfg.unitary_flags # Finally, run the linearity check from guppylang_internals.checker.linearity_checker import check_cfg_linearity @@ -352,3 +356,28 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]: """ for i in range(len(xs) - 1, -1, -1): yield i, xs[i] + +# TODO (k.hirata): This function is supposed to detect loops and assignments in daggered blocks. +# However, this has to be called much earlier since the builder already deconstructs loops +# to CFGs. +# def check_invalid_in_dagger(cfg: CFG) -> None: +# """Check that there are no invalid constructs in a daggered CFG. +# This checker checks the case the UnitaryFlags is given by +# annotation (i.e., not inferred from `with dagger:`). +# """ +# if UnitaryFlags.Dagger not in cfg.unitary_flags: +# return + +# for cfg_bb in cfg.bbs: +# for stmt in cfg_bb.statements: +# loops = loop_in_ast(stmt) +# if len(loops) != 0: +# loop = next(iter(loops)) +# err = InvalidUnderDagger(loop, "Loop") +# raise GuppyError(err) +# # Note: sub-diagnostic for dagger context is not available here + +# if cfg_bb.vars.assigned: +# _, v = next(iter(cfg_bb.vars.assigned.items())) +# err = InvalidUnderDagger(v, "Assignment") +# raise GuppyError(err) diff --git a/guppylang-internals/src/guppylang_internals/checker/func_checker.py b/guppylang-internals/src/guppylang_internals/checker/func_checker.py index ea848f678..dc1e2d9a4 100644 --- a/guppylang-internals/src/guppylang_internals/checker/func_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/func_checker.py @@ -256,7 +256,6 @@ def check_signature( passed. This will be used to check or infer the type annotation for the `self` argument. """ - # TODO:(k.hirata) unitary_flags if len(func_def.args.posonlyargs) != 0: raise GuppyError( UnsupportedError(func_def.args.posonlyargs[0], "Positional-only parameters") diff --git a/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py b/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py index 6f6987ce7..96e326ff5 100644 --- a/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/linearity_checker.py @@ -880,6 +880,7 @@ def live_places_row( result_cfg.maybe_ass_before = { checked[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs } + result_cfg.unitary_flags = cfg.unitary_flags for bb in cfg.bbs: checked[bb].predecessors = [checked[pred] for pred in bb.predecessors] checked[bb].successors = [checked[succ] for succ in bb.successors] diff --git a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py index e5033e9e9..075211ffe 100644 --- a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py @@ -4,6 +4,7 @@ from guppylang_internals.ast_util import loop_in_ast, with_loc from guppylang_internals.cfg.bb import BB +from guppylang_internals.cfg.cfg import CFG from guppylang_internals.checker.cfg_checker import check_cfg from guppylang_internals.checker.core import Context, Variable from guppylang_internals.checker.errors.generic import InvalidUnderDagger @@ -100,8 +101,7 @@ def check_modified_block_signature( ) -> FunctionType: """Check and create the signature of a function definition for a body of a `With` block.""" - # TODO (k.hirata): set unitary flags - unitary_flags = UnitaryFlags.NoFlags + unitary_flags = modified_block.flags() func_ty = FunctionType( [ diff --git a/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py b/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py index 8dc8b7ca1..46e737dc3 100644 --- a/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py @@ -93,11 +93,6 @@ def visit_Assign(self, node: ast.Assign) -> None: def visit_AugAssign(self, node: ast.AugAssign) -> None: self._check_assign(node) - def visit_For(self, node: ast.For) -> None: - if UnitaryFlags.Dagger in self.flags: - raise GuppyError(InvalidUnderDagger(node, "Loop")) - self.generic_visit(node) - def visit_PlaceNode(self, node: PlaceNode) -> None: if UnitaryFlags.Dagger in self.flags and contains_subscript(node.place): raise GuppyError( diff --git a/guppylang-internals/src/guppylang_internals/definition/function.py b/guppylang-internals/src/guppylang_internals/definition/function.py index d52e04ae0..37090b4bd 100644 --- a/guppylang-internals/src/guppylang_internals/definition/function.py +++ b/guppylang-internals/src/guppylang_internals/definition/function.py @@ -75,7 +75,7 @@ class RawFunctionDef(ParsableDef): def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func, sources) - ty = check_signature(func_ast, globals, self.id) + ty = check_signature(func_ast, globals, self.id, unitary_flags=self.unitary_flags) return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring) diff --git a/guppylang-internals/src/guppylang_internals/nodes.py b/guppylang-internals/src/guppylang_internals/nodes.py index 96d543140..bfd7923e7 100644 --- a/guppylang-internals/src/guppylang_internals/nodes.py +++ b/guppylang-internals/src/guppylang_internals/nodes.py @@ -506,7 +506,8 @@ def push_modifier(self, modifier: Modifier) -> None: else: raise TypeError(f"Unknown modifier: {modifier}") - def add_flags(self, flags: UnitaryFlags) -> UnitaryFlags: + def flags(self) -> UnitaryFlags: + flags = UnitaryFlags.NoFlags if self.is_dagger(): flags |= UnitaryFlags.Dagger if self.is_control(): diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index 447c5bdf5..84bbbba71 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -4,6 +4,7 @@ from collections.abc import Callable, Sequence from types import FrameType from typing import Any, ParamSpec, TypeVar, cast +from dataclasses import replace from guppylang_internals.ast_util import annotate_location from guppylang_internals.compiler.core import ( @@ -84,26 +85,10 @@ class _Guppy: """Class for the `@guppy` decorator.""" - # trying to support both `@guppy` and `@guppy(unitary_flags=...)` styles - def __call__( - self, - f: Callable[P, T] | None = None, - unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, - ) -> ( - GuppyFunctionDefinition[P, T] - | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]] - ): - def register(fn: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: - defn = RawFunctionDef( - DefId.fresh(), fn.__name__, None, fn, unitary_flags=unitary_flags - ) - DEF_STORE.register_def(defn, get_calling_frame()) - return GuppyFunctionDefinition(defn) - - if f is None: - return register - else: - return register(f) + def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: + defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f) + DEF_STORE.register_def(defn, get_calling_frame()) + return GuppyFunctionDefinition(defn) def comptime(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: """Registers a function to be executed at compile-time during Guppy compilation, @@ -185,7 +170,7 @@ def type_var( .. code-block:: python from guppylang import guppy - T = guppy.type_var("T") + T = guppy.ty @guppy def identity(x: T) -> T: @@ -238,27 +223,11 @@ def hugr_op( ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: return hugr_op(op, checker, higher_order_value, name, signature) - def declare( - self, - f: Callable[P, T] | None = None, - unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags, - ) -> ( - GuppyFunctionDefinition[P, T] - | Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]] - ): + def declare(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: """Declares a Guppy function without defining it.""" - - def register(fn: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: - defn = RawFunctionDecl( - DefId.fresh(), fn.__name__, None, fn, unitary_flags=unitary_flags - ) - DEF_STORE.register_def(defn, get_calling_frame()) - return GuppyFunctionDefinition(defn) - - if f is None: - return register - else: - return register(f) + defn = RawFunctionDecl(DefId.fresh(), f.__name__, None, f) + DEF_STORE.register_def(defn, get_calling_frame()) + return GuppyFunctionDefinition(defn) def overload( self, *funcs: Any @@ -480,6 +449,48 @@ def foo(default_reg: array[qubit, 2], DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) + def with_unitary_flags( + self, flags: UnitaryFlags + ) -> Callable[[GuppyFunctionDefinition[P, T]], GuppyFunctionDefinition[P, T]]: + """Wrap a Guppy function with specific unitarity annotations. + + .. code-block:: python + + from guppylang import guppy + from guppylang.std.quantum import qubit, h, UnitaryFlags + + @guppy.with_unitary_flags(UnitaryFlags.Unitary) + @guppy + def apply_h(q: qubit) -> None: + h(q) + """ + + def decorator( + func: GuppyFunctionDefinition[P, T] + ) -> GuppyFunctionDefinition[P, T]: + if not isinstance(func, GuppyFunctionDefinition): + raise TypeError( + "@guppy.with_unitary_flags must be applied above @guppy" + ) + + wrapped = func.wrapped + # In future we may want to support other function-like definitions here + # if not isinstance(wrapped, AnyRawFunctionDef): + if not isinstance(wrapped, RawFunctionDef | RawCustomFunctionDef | RawFunctionDecl): + raise TypeError( + f"Object `{func}` does not have a unitarity annotation" + ) + + if wrapped.unitary_flags == flags: + return func + + updated = replace(wrapped, unitary_flags=flags) + DEF_STORE.raw_defs[updated.id] = updated + return GuppyFunctionDefinition(updated) + + return decorator + + def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr: """Helper function to parse expressions that are provided as strings. diff --git a/tests/error/modifier_errors/captured_var_inout_own.err b/tests/error/modifier_errors/captured_var_inout_own.err index ee368c4b1..7b4cd6696 100644 --- a/tests/error/modifier_errors/captured_var_inout_own.err +++ b/tests/error/modifier_errors/captured_var_inout_own.err @@ -1,12 +1,12 @@ -Error: Not owned (at $FILE:14:16) +Error: Not owned (at $FILE:15:16) | -12 | a = qubit() -13 | with dagger: -14 | discard(a) +13 | a = qubit() +14 | with dagger: +15 | discard(a) | ^ Function `discard` wants to take ownership of this argument, | but `__modified__()` doesn't own `a` | -12 | a = qubit() +13 | a = qubit() | - Argument `a` is only borrowed. Consider taking ownership: | `a: qubit @owned` diff --git a/tests/error/modifier_errors/captured_var_inout_own.py b/tests/error/modifier_errors/captured_var_inout_own.py index b136b86e8..76a68ff57 100644 --- a/tests/error/modifier_errors/captured_var_inout_own.py +++ b/tests/error/modifier_errors/captured_var_inout_own.py @@ -2,7 +2,8 @@ from guppylang.std.quantum import qubit, owned, UnitaryFlags -@guppy.declare(unitary_flags=UnitaryFlags.Dagger) +@guppy.with_unitary_flags(UnitaryFlags.Dagger) +@guppy.declare def discard(q: qubit @ owned) -> None: ... diff --git a/tests/error/modifier_errors/captured_var_inout_reassign.err b/tests/error/modifier_errors/captured_var_inout_reassign.err index 3a8ab0040..7acefcffc 100644 --- a/tests/error/modifier_errors/captured_var_inout_reassign.err +++ b/tests/error/modifier_errors/captured_var_inout_reassign.err @@ -1,8 +1,8 @@ -Error: Drop violation (at $FILE:11:4) +Error: Drop violation (at $FILE:12:4) | - 9 | @guppy -10 | def test() -> None: -11 | a = qubit() +10 | @guppy +11 | def test() -> None: +12 | a = qubit() | ^ Variable `a` with non-droppable type `qubit` is leaked Help: Make sure that `a` is consumed or returned to avoid the leak diff --git a/tests/error/modifier_errors/captured_var_inout_reassign.py b/tests/error/modifier_errors/captured_var_inout_reassign.py index 529ce22a4..dde19c6a4 100644 --- a/tests/error/modifier_errors/captured_var_inout_reassign.py +++ b/tests/error/modifier_errors/captured_var_inout_reassign.py @@ -2,7 +2,8 @@ from guppylang.std.quantum import qubit, UnitaryFlags -@guppy.declare(unitary_flags=UnitaryFlags.Dagger) +@guppy.with_unitary_flags(UnitaryFlags.Dagger) +@guppy.declare def use(q: qubit) -> None: ... diff --git a/tests/error/modifier_errors/ctrl_arg_copy.err b/tests/error/modifier_errors/ctrl_arg_copy.err index 54acc0b64..d85cfc3e1 100644 --- a/tests/error/modifier_errors/ctrl_arg_copy.err +++ b/tests/error/modifier_errors/ctrl_arg_copy.err @@ -1,12 +1,12 @@ -Error: Copy violation (at $FILE:17:12) +Error: Copy violation (at $FILE:18:12) | -15 | q = qubit() -16 | with control(q): -17 | use(q) +16 | q = qubit() +17 | with control(q): +18 | use(q) | ^ Variable `q` with non-copyable type `qubit` cannot be | borrowed ... | -16 | with control(q): +17 | with control(q): | - since it was already borrowed here Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/ctrl_arg_copy.py b/tests/error/modifier_errors/ctrl_arg_copy.py index 0ea99862a..2acf3f8a6 100644 --- a/tests/error/modifier_errors/ctrl_arg_copy.py +++ b/tests/error/modifier_errors/ctrl_arg_copy.py @@ -6,7 +6,8 @@ def discard(q: qubit @ owned) -> None: ... -@guppy.declare(unitary_flags=UnitaryFlags.Control) +@guppy.with_unitary_flags(UnitaryFlags.Control) +@guppy.declare def use(q: qubit) -> None: ... diff --git a/tests/error/modifier_errors/flag_call.err b/tests/error/modifier_errors/flag_call.err new file mode 100644 index 000000000..83562dc03 --- /dev/null +++ b/tests/error/modifier_errors/flag_call.err @@ -0,0 +1,8 @@ +Error: Unitary constraint violation (at $FILE:12:4) + | +10 | @guppy +11 | def test(x: qubit) -> None: +12 | foo(x) + | ^^^^^^ This function cannot be called in a dagger context + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/flag_call.py b/tests/error/modifier_errors/flag_call.py index 3cee648d1..fffbf1dae 100644 --- a/tests/error/modifier_errors/flag_call.py +++ b/tests/error/modifier_errors/flag_call.py @@ -1,15 +1,15 @@ -from guppylang_internals.tys.ty import UnitaryFlags from guppylang.decorator import guppy -from guppylang.std.quantum import qubit +from guppylang.std.quantum import qubit, UnitaryFlags @guppy.declare def foo(x: qubit) -> None: ... -@guppy(unitary_flags=UnitaryFlags.Dagger) +@guppy.with_unitary_flags(UnitaryFlags.Dagger) +@guppy def test(x: qubit) -> None: foo(x) -test.compile() +test.compile_function() diff --git a/tests/error/modifier_errors/flag_dagger_assign.err b/tests/error/modifier_errors/flag_dagger_assign.err new file mode 100644 index 000000000..6eb48564c --- /dev/null +++ b/tests/error/modifier_errors/flag_dagger_assign.err @@ -0,0 +1,8 @@ +Error: Invalid expression in dagger (at $FILE:8:4) + | +6 | @guppy +7 | def test() -> None: +8 | x = 3 + | ^^^^^ Assignment found in a dagger context + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/flag_dagger_assign.py b/tests/error/modifier_errors/flag_dagger_assign.py index ebda41548..da834cc04 100644 --- a/tests/error/modifier_errors/flag_dagger_assign.py +++ b/tests/error/modifier_errors/flag_dagger_assign.py @@ -2,7 +2,8 @@ from guppylang.std.quantum import UnitaryFlags -@guppy(unitary_flags=UnitaryFlags.Dagger) +@guppy.with_unitary_flags(UnitaryFlags.Dagger) +@guppy def test() -> None: x = 3 diff --git a/tests/error/modifier_errors/flag_loop.err b/tests/error/modifier_errors/flag_loop.err new file mode 100644 index 000000000..3a93033bc --- /dev/null +++ b/tests/error/modifier_errors/flag_loop.err @@ -0,0 +1,8 @@ +Error: Invalid expression in dagger (at $FILE:9:13) + | +7 | @guppy +8 | def test() -> None: +9 | for _ in range(3): + | ^^^^^^^^ Assignment found in a dagger context + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/flag_loop.py b/tests/error/modifier_errors/flag_loop.py index 17f48a9da..5c667abd6 100644 --- a/tests/error/modifier_errors/flag_loop.py +++ b/tests/error/modifier_errors/flag_loop.py @@ -1,10 +1,12 @@ from guppylang.decorator import guppy from guppylang.std.quantum import UnitaryFlags +from guppylang.std.array import array -@guppy(unitary_flags=UnitaryFlags.Dagger) +@guppy.with_unitary_flags(UnitaryFlags.Dagger) +@guppy def test() -> None: - for _ in range(3): + while True: pass diff --git a/tests/error/modifier_errors/higher_order.err b/tests/error/modifier_errors/higher_order.err index e8432761f..69bef82ab 100644 --- a/tests/error/modifier_errors/higher_order.err +++ b/tests/error/modifier_errors/higher_order.err @@ -1,8 +1,8 @@ -Error: Unitary constraint violation (at $FILE:18:8) +Error: Unitary constraint violation (at $FILE:12:4) | -16 | q = qubit() -17 | with dagger: -18 | test_ho(h, q) - | ^^^^^^^^^^^^^ This function cannot be called in a dagger context +10 | def test_ho(f: Callable[[qubit], None], q: qubit) -> None: +11 | # There is no way to use specify flags for f +12 | f(q) + | ^^^^ This function cannot be called in a dagger context Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/higher_order.py b/tests/error/modifier_errors/higher_order.py index 51c826324..269878cd0 100644 --- a/tests/error/modifier_errors/higher_order.py +++ b/tests/error/modifier_errors/higher_order.py @@ -5,7 +5,8 @@ # The flag is required to be used in dagger context -@guppy(unitary_flags=UnitaryFlags.Dagger) +@guppy.with_unitary_flags(UnitaryFlags.Dagger) +@guppy def test_ho(f: Callable[[qubit], None], q: qubit) -> None: # There is no way to use specify flags for f f(q) From 7e7bbde531cf8e6f4040871964431ba2072a8ac5 Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Thu, 9 Oct 2025 15:19:43 +0100 Subject: [PATCH 07/13] commented out commit number of tket repository --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ce3752060..0c0de3db8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ miette-py = { workspace = true } # Uncomment these to test the latest dependency version during development # hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "50a2bac" } -tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "f0bc211" } +# tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "f0bc211" } [build-system] requires = ["hatchling"] From 0c45a4d0a36039c504c1e67ac00ceeb0d7adbb84 Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Thu, 9 Oct 2025 15:55:06 +0100 Subject: [PATCH 08/13] error handling for loops in dagger --- .../checker/cfg_checker.py | 32 ++----------------- .../checker/func_checker.py | 2 ++ .../checker/modifier_checker.py | 2 -- .../checker/unitary_checker.py | 31 +++++++++++++++++- .../definition/function.py | 4 ++- guppylang/src/guppylang/decorator.py | 13 ++++---- tests/error/modifier_errors/flag_loop.err | 14 ++++---- tests/error/modifier_errors/flag_loop.py | 2 +- 8 files changed, 52 insertions(+), 48 deletions(-) diff --git a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py index 7fb96fe0c..d83c7151f 100644 --- a/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/cfg_checker.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from typing import ClassVar, Generic, TypeVar -from guppylang_internals.ast_util import line_col, loop_in_ast +from guppylang_internals.ast_util import line_col from guppylang_internals.cfg.bb import BB from guppylang_internals.cfg.cfg import CFG, BaseCFG from guppylang_internals.checker.core import ( @@ -21,13 +21,12 @@ V, Variable, ) -from guppylang_internals.checker.errors.generic import InvalidUnderDagger from guppylang_internals.checker.expr_checker import ExprSynthesizer, to_bool from guppylang_internals.checker.stmt_checker import StmtChecker from guppylang_internals.diagnostic import Error, Note from guppylang_internals.error import GuppyError from guppylang_internals.tys.param import Parameter -from guppylang_internals.tys.ty import InputFlags, Type, UnitaryFlags +from guppylang_internals.tys.ty import InputFlags, Type Row = Sequence[V] @@ -89,8 +88,6 @@ def check_cfg( inout_vars = [v for v in inputs if InputFlags.Inout in v.flags] cfg.analyze(ass_before, ass_before, [v.name for v in inout_vars]) - check_invalid_in_dagger(cfg) - # We start by compiling the entry BB checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty) checked_cfg.entry_bb = check_bb( @@ -356,28 +353,3 @@ def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]: """ for i in range(len(xs) - 1, -1, -1): yield i, xs[i] - -# TODO (k.hirata): This function is supposed to detect loops and assignments in daggered blocks. -# However, this has to be called much earlier since the builder already deconstructs loops -# to CFGs. -# def check_invalid_in_dagger(cfg: CFG) -> None: -# """Check that there are no invalid constructs in a daggered CFG. -# This checker checks the case the UnitaryFlags is given by -# annotation (i.e., not inferred from `with dagger:`). -# """ -# if UnitaryFlags.Dagger not in cfg.unitary_flags: -# return - -# for cfg_bb in cfg.bbs: -# for stmt in cfg_bb.statements: -# loops = loop_in_ast(stmt) -# if len(loops) != 0: -# loop = next(iter(loops)) -# err = InvalidUnderDagger(loop, "Loop") -# raise GuppyError(err) -# # Note: sub-diagnostic for dagger context is not available here - -# if cfg_bb.vars.assigned: -# _, v = next(iter(cfg_bb.vars.assigned.items())) -# err = InvalidUnderDagger(v, "Assignment") -# raise GuppyError(err) diff --git a/guppylang-internals/src/guppylang_internals/checker/func_checker.py b/guppylang-internals/src/guppylang_internals/checker/func_checker.py index dc1e2d9a4..666b92fb6 100644 --- a/guppylang-internals/src/guppylang_internals/checker/func_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/func_checker.py @@ -16,6 +16,7 @@ from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg from guppylang_internals.checker.core import Context, Globals, Place, Variable from guppylang_internals.checker.errors.generic import UnsupportedError +from guppylang_internals.checker.unitary_checker import check_invalid_under_dagger from guppylang_internals.definition.common import DefId from guppylang_internals.definition.ty import TypeDef from guppylang_internals.diagnostic import Error, Help, Note @@ -137,6 +138,7 @@ def check_global_func_def( returns_none = isinstance(ty.output, NoneType) assert ty.input_names is not None + check_invalid_under_dagger(func_def, ty.unitary_flags) cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags) inputs = [ Variable(x, inp.ty, loc, inp.flags, is_func_input=True) diff --git a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py index 075211ffe..ecba3783d 100644 --- a/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/modifier_checker.py @@ -4,7 +4,6 @@ from guppylang_internals.ast_util import loop_in_ast, with_loc from guppylang_internals.cfg.bb import BB -from guppylang_internals.cfg.cfg import CFG from guppylang_internals.checker.cfg_checker import check_cfg from guppylang_internals.checker.core import Context, Variable from guppylang_internals.checker.errors.generic import InvalidUnderDagger @@ -17,7 +16,6 @@ InputFlags, NoneType, Type, - UnitaryFlags, ) diff --git a/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py b/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py index 46e737dc3..2a4539281 100644 --- a/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/unitary_checker.py @@ -1,7 +1,7 @@ import ast from typing import Any -from guppylang_internals.ast_util import get_type +from guppylang_internals.ast_util import find_nodes, get_type, loop_in_ast from guppylang_internals.checker.cfg_checker import CheckedBB, CheckedCFG from guppylang_internals.checker.core import Place, contains_subscript from guppylang_internals.checker.errors.generic import ( @@ -26,6 +26,35 @@ from guppylang_internals.tys.ty import FunctionType, UnitaryFlags +def check_invalid_under_dagger( + fn_def: ast.FunctionDef, unitary_flags: UnitaryFlags +) -> None: + """Check that there are no invalid constructs in a daggered CFG. + This checker checks the case the UnitaryFlags is given by + annotation (i.e., not inferred from `with dagger:`). + """ + if UnitaryFlags.Dagger not in unitary_flags: + return + + for stmt in fn_def.body: + loops = loop_in_ast(stmt) + if len(loops) != 0: + loop = next(iter(loops)) + err = InvalidUnderDagger(loop, "Loop") + raise GuppyError(err) + # Note: sub-diagnostic for dagger context is not available here + + found = find_nodes( + lambda n: isinstance(n, ast.Assign | ast.AnnAssign | ast.AugAssign), + stmt, + {ast.FunctionDef}, + ) + if len(found) != 0: + assign = next(iter(found)) + err = InvalidUnderDagger(assign, "Assignment") + raise GuppyError(err) + + class BBUnitaryChecker(ast.NodeVisitor): flags: UnitaryFlags diff --git a/guppylang-internals/src/guppylang_internals/definition/function.py b/guppylang-internals/src/guppylang_internals/definition/function.py index 37090b4bd..a088c0d2a 100644 --- a/guppylang-internals/src/guppylang_internals/definition/function.py +++ b/guppylang-internals/src/guppylang_internals/definition/function.py @@ -75,7 +75,9 @@ class RawFunctionDef(ParsableDef): def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func, sources) - ty = check_signature(func_ast, globals, self.id, unitary_flags=self.unitary_flags) + ty = check_signature( + func_ast, globals, self.id, unitary_flags=self.unitary_flags + ) return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring) diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index 84bbbba71..4dc5fa694 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -2,9 +2,9 @@ import builtins import inspect from collections.abc import Callable, Sequence +from dataclasses import replace from types import FrameType from typing import Any, ParamSpec, TypeVar, cast -from dataclasses import replace from guppylang_internals.ast_util import annotate_location from guppylang_internals.compiler.core import ( @@ -466,7 +466,7 @@ def apply_h(q: qubit) -> None: """ def decorator( - func: GuppyFunctionDefinition[P, T] + func: GuppyFunctionDefinition[P, T], ) -> GuppyFunctionDefinition[P, T]: if not isinstance(func, GuppyFunctionDefinition): raise TypeError( @@ -476,10 +476,10 @@ def decorator( wrapped = func.wrapped # In future we may want to support other function-like definitions here # if not isinstance(wrapped, AnyRawFunctionDef): - if not isinstance(wrapped, RawFunctionDef | RawCustomFunctionDef | RawFunctionDecl): - raise TypeError( - f"Object `{func}` does not have a unitarity annotation" - ) + if not isinstance( + wrapped, RawFunctionDef | RawCustomFunctionDef | RawFunctionDecl + ): + raise TypeError(f"Object `{func}` does not have a unitarity annotation") if wrapped.unitary_flags == flags: return func @@ -491,7 +491,6 @@ def decorator( return decorator - def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr: """Helper function to parse expressions that are provided as strings. diff --git a/tests/error/modifier_errors/flag_loop.err b/tests/error/modifier_errors/flag_loop.err index 3a93033bc..629ed871f 100644 --- a/tests/error/modifier_errors/flag_loop.err +++ b/tests/error/modifier_errors/flag_loop.err @@ -1,8 +1,10 @@ -Error: Invalid expression in dagger (at $FILE:9:13) - | -7 | @guppy -8 | def test() -> None: -9 | for _ in range(3): - | ^^^^^^^^ Assignment found in a dagger context +Error: Invalid expression in dagger (at $FILE:9:4) + | + 7 | @guppy + 8 | def test() -> None: + 9 | for _ in range(10): + | ^^^^^^^^^^^^^^^^^^^ +10 | pass + | ^^^^^^^^^^^^ Loop found in a dagger context Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/flag_loop.py b/tests/error/modifier_errors/flag_loop.py index 5c667abd6..2c9ec6433 100644 --- a/tests/error/modifier_errors/flag_loop.py +++ b/tests/error/modifier_errors/flag_loop.py @@ -6,7 +6,7 @@ @guppy.with_unitary_flags(UnitaryFlags.Dagger) @guppy def test() -> None: - while True: + for _ in range(10): pass From 0354431f4f6003e0bb61546c0632f81f4bda921c Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Thu, 9 Oct 2025 16:02:13 +0100 Subject: [PATCH 09/13] minor --- .../src/guppylang_internals/checker/func_checker.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/guppylang-internals/src/guppylang_internals/checker/func_checker.py b/guppylang-internals/src/guppylang_internals/checker/func_checker.py index 666b92fb6..332d99083 100644 --- a/guppylang-internals/src/guppylang_internals/checker/func_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/func_checker.py @@ -156,10 +156,8 @@ def check_nested_func_def( func_def: NestedFunctionDef, bb: BB, ctx: Context, - # unitary_flags: (k.hirata) ) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" - # unitary_flags: (k.hirata) func_ty = check_signature(func_def, ctx.globals) assert func_ty.input_names is not None @@ -220,7 +218,6 @@ def check_nested_func_def( from guppylang.defs import GuppyDefinition from guppylang_internals.definition.function import ParsedFunctionDef - # TODO (k.hirata): unitary_flags func = ParsedFunctionDef(def_id, func_def.name, func_def, func_ty, None) DEF_STORE.register_def(func, None) ENGINE.parsed[def_id] = func From 520f7c80cd325f96834b5317a40b10ea8659fe6c Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Thu, 9 Oct 2025 16:02:36 +0100 Subject: [PATCH 10/13] update modifier integration test --- tests/integration/test_modifier.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_modifier.py b/tests/integration/test_modifier.py index 2108d5a85..b82046216 100644 --- a/tests/integration/test_modifier.py +++ b/tests/integration/test_modifier.py @@ -1,5 +1,5 @@ from guppylang.decorator import guppy -from guppylang.std.quantum import qubit +from guppylang.std.quantum import qubit, UnitaryFlags from guppylang.std.num import nat from guppylang.std.builtins import owned from guppylang.std.array import array @@ -101,6 +101,7 @@ def bar(q: qubit) -> None: def test_free_linear_variable_in_modifier(validate): T = guppy.type_var("T", copyable=False, droppable=False) + @guppy.with_unitary_flags(UnitaryFlags.Control) @guppy.declare def use(a: T) -> None: ... @@ -123,9 +124,6 @@ def test_free_copyable_variable_in_modifier(validate): @guppy.declare def use(a: T) -> None: ... - @guppy.declare - def discard(a: T @ owned) -> None: ... - @guppy def bar(q: array[qubit, 3]) -> None: a = 3 From cdaf60993c48a3969e5d2817a932b36efcaf608f Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Thu, 9 Oct 2025 16:14:28 +0100 Subject: [PATCH 11/13] comments --- guppylang-internals/src/guppylang_internals/cfg/builder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index a8c10ae66..5f3101a46 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -308,7 +308,10 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None: modifier = self._handle_withitem(item) new_node.push_modifier(modifier) - # TODO: its parent's flags need to be added too + # FIXME: Currently, the unitary flags is not set correctly if there are nested + # `with` blocks. This is because the inner block's unitary flags are not + # propagated from the outer block. The following line should calculate the sum + # of the unitary flags of the outer block and modifiers of this `with` block. unitary_flags = new_node.flags() object.__setattr__(cfg, "unitary_flags", unitary_flags) From 26e00e2a33dabe70aacc6348111b4b18f5de3fe2 Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Thu, 9 Oct 2025 16:22:24 +0100 Subject: [PATCH 12/13] added a test --- .../src/guppylang_internals/cfg/builder.py | 6 +++--- tests/error/modifier_errors/flags_nested.err | 8 ++++++++ tests/error/modifier_errors/flags_nested.py | 20 +++++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 tests/error/modifier_errors/flags_nested.err create mode 100644 tests/error/modifier_errors/flags_nested.py diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index 5f3101a46..80c56ed20 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -309,9 +309,9 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None: new_node.push_modifier(modifier) # FIXME: Currently, the unitary flags is not set correctly if there are nested - # `with` blocks. This is because the inner block's unitary flags are not - # propagated from the outer block. The following line should calculate the sum - # of the unitary flags of the outer block and modifiers of this `with` block. + # `with` blocks. This is because the outer block's unitary flags are not + # propagated to the outer block. The following line should calculate the sum + # of the unitary flags of the outer block and modifiers applied in this `with` block. unitary_flags = new_node.flags() object.__setattr__(cfg, "unitary_flags", unitary_flags) diff --git a/tests/error/modifier_errors/flags_nested.err b/tests/error/modifier_errors/flags_nested.err new file mode 100644 index 000000000..39d1a2a9d --- /dev/null +++ b/tests/error/modifier_errors/flags_nested.err @@ -0,0 +1,8 @@ +Error: Unitary constraint violation (at $FILE:17:12) + | +15 | with dagger: +16 | with power(2): +17 | foo(q) + | ^^^^^^ This function cannot be called in a dagger context + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/modifier_errors/flags_nested.py b/tests/error/modifier_errors/flags_nested.py new file mode 100644 index 000000000..dd4f2b337 --- /dev/null +++ b/tests/error/modifier_errors/flags_nested.py @@ -0,0 +1,20 @@ +from guppylang.decorator import guppy +from guppylang.std.quantum import qubit, UnitaryFlags +from guppylang.std.array import array + + +@guppy.with_unitary_flags(UnitaryFlags.Power) +@guppy +def foo(q: qubit) -> None: + pass + + +@guppy +def test() -> None: + q = qubit() + with dagger: + with power(2): + foo(q) + + +test.compile() From d3b1577f783d735f1a19c121bf9520f2c1681bdc Mon Sep 17 00:00:00 2001 From: hkengo-qtnm Date: Fri, 10 Oct 2025 01:18:40 +0100 Subject: [PATCH 13/13] small integration test using numpy --- .../src/guppylang_internals/cfg/builder.py | 3 +- pyproject.toml | 3 +- tests/integration/test_modifier_emulate.py | 100 ++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_modifier_emulate.py diff --git a/guppylang-internals/src/guppylang_internals/cfg/builder.py b/guppylang-internals/src/guppylang_internals/cfg/builder.py index 80c56ed20..8c317ba2f 100644 --- a/guppylang-internals/src/guppylang_internals/cfg/builder.py +++ b/guppylang-internals/src/guppylang_internals/cfg/builder.py @@ -311,7 +311,8 @@ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None: # FIXME: Currently, the unitary flags is not set correctly if there are nested # `with` blocks. This is because the outer block's unitary flags are not # propagated to the outer block. The following line should calculate the sum - # of the unitary flags of the outer block and modifiers applied in this `with` block. + # of the unitary flags of the outer block and modifiers applied in this + # `with` block. unitary_flags = new_node.flags() object.__setattr__(cfg, "unitary_flags", unitary_flags) diff --git a/pyproject.toml b/pyproject.toml index 0c0de3db8..0c6b8007a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,8 @@ miette-py = { workspace = true } # Uncomment these to test the latest dependency version during development # hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "50a2bac" } -# tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "f0bc211" } +tket = { git = "https://github.com/CQCL/tket2", subdirectory = "tket-py", rev = "a8c8bd1" } +selene-hugr-qis-compiler = { git = "https://github.com/CQCL/tket2", subdirectory = "qis-compiler", rev = "a8c8bd1" } [build-system] requires = ["hatchling"] diff --git a/tests/integration/test_modifier_emulate.py b/tests/integration/test_modifier_emulate.py new file mode 100644 index 000000000..f8655c499 --- /dev/null +++ b/tests/integration/test_modifier_emulate.py @@ -0,0 +1,100 @@ +import numpy as np +from guppylang.decorator import guppy +from guppylang.std.array import array +from guppylang.std.debug import state_result +from guppylang.std.quantum import ( + discard, + discard_array, + qubit, + UnitaryFlags, + cx, + v, + h, + x, + s, + t, + toffoli, + sdg, + ry, + crz, +) +from guppylang.std.angles import angle + +# Dummy variables to suppress Undefined name +# TODO: `ruff` fails when without these, which need to be fixed +dagger = object() +control = object() + + +@guppy.with_unitary_flags(UnitaryFlags.Unitary) +@guppy +def foo(q1: qubit, q2: qubit, q3: qubit, q4: qubit) -> None: + h(q1) + cx(q1, q2) + v(q1) + h(q1) + x(q3) + with dagger: + s(q1) + with control(q2): + toffoli(q1, q3, q4) + t(q3) + sdg(q1) + ry(q1, angle(0.12)) + crz(q1, q3, angle(0.38)) + h(q1) + toffoli(q1, q2, q3) + s(q3) + + +def test_dagger(): + @guppy + def dagger_involution() -> None: + q1 = qubit() + q2 = qubit() + q3 = qubit() + q4 = qubit() + + with dagger: + foo(q1, q2, q3, q4) + foo(q1, q2, q3, q4) + q = array(q1, q2, q3, q4) + state_result("zero", q) + discard_array(q) + + shots = dagger_involution.emulator(n_qubits=6).statevector_sim().with_seed(1).run() + + for states in shots.partial_state_dicts(): + state_vector1 = states["zero"].as_single_state() + + diff = np.abs( + state_vector1 - np.array([1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + ) + assert np.all(diff < 1e-6), f"State vector non zero: {diff}" + + +def test_ctrl(): + @guppy + def test_ctrl() -> None: + q = array(qubit() for _ in range(4)) + foo(q[0], q[1], q[2], q[3]) + state_result("foo", q) + discard_array(q) + + q = array(qubit() for _ in range(4)) + c = qubit() + x(c) + with control(c): + foo(q[0], q[1], q[2], q[3]) + state_result("ctrl_foo", q) + discard_array(q) + discard(c) + + shots = test_ctrl.emulator(n_qubits=7).statevector_sim().with_seed(1).run() + + for states in shots.partial_state_dicts(): + state_vector1 = states["foo"].as_single_state() + state_vector2 = states["ctrl_foo"].as_single_state() + + diff = np.abs(state_vector1 - state_vector2) + assert np.all(diff < 1e-6), f"State vectors are different: {diff}"