diff --git a/guppylang-internals/src/guppylang_internals/ast_util.py b/guppylang-internals/src/guppylang_internals/ast_util.py index 454f431f8..16f305ebd 100644 --- a/guppylang-internals/src/guppylang_internals/ast_util.py +++ b/guppylang-internals/src/guppylang_internals/ast_util.py @@ -1,6 +1,7 @@ import ast import textwrap from collections.abc import Callable, Mapping, Sequence +from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast @@ -418,3 +419,55 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS else: node = ast.parse(source).body[0] return source, node, line_offset + + +ImportStmt = ast.Import | ast.ImportFrom + + +class ImportMap: + """Maps imported names to their fully qualified module paths.""" + + imports_by_name: dict[str, tuple[ImportStmt, ast.alias]] + used_names: set[str] + + def __init__(self) -> None: + self.imports_by_name = {} + self.used_names = set() + + def register_import(self, name: str, stmt: ImportStmt, alias: ast.alias) -> None: + self.imports_by_name[name] = (stmt, alias) + + def has(self, name: str) -> bool: + return name in self.imports_by_name + + def use(self, name: str) -> None: + """Marks a name as used, causing the ast generated by `dump_ast` to include an + import statement for it.""" + # if not self.has(name): + # raise KeyError( + # f"Name `{name}` is not registered in this import map. " + # f"Perhaps it is not a top level import?" + # ) + + return self.used_names.add(name) + + def dump_ast(self) -> list[ImportStmt]: + """Returns a list of AST statements representing the imports registered in + this map.""" + used_aliases_by_original_stmt = dict[ImportStmt, list[ast.alias]]() + for name, (stmt, alias) in self.imports_by_name.items(): + if name not in self.used_names: + continue + + if stmt not in used_aliases_by_original_stmt: + used_aliases_by_original_stmt[stmt] = [] + + used_aliases_by_original_stmt[stmt].append(alias) + + stmts = [] + for stmt, used_aliases in used_aliases_by_original_stmt.items(): + new_stmt = deepcopy(stmt) + new_stmt.names = used_aliases + stmts.append(new_stmt) + + return stmts diff --git a/guppylang-internals/src/guppylang_internals/checker/func_checker.py b/guppylang-internals/src/guppylang_internals/checker/func_checker.py index b3695c1d1..ff19fc0a5 100644 --- a/guppylang-internals/src/guppylang_internals/checker/func_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/func_checker.py @@ -224,8 +224,8 @@ def check_nested_func_def( func_ty, None, # Even though global, this function will be private to the built hugr, - # so the hugr name does not really matter. - hugr_name=func_def.name, + # so the link name does not really matter. + link_name=func_def.name, ) DEF_STORE.register_def(func, None) ENGINE.parsed[def_id] = func diff --git a/guppylang-internals/src/guppylang_internals/definition/declaration.py b/guppylang-internals/src/guppylang_internals/definition/declaration.py index 191fa42a1..3fd25cdec 100644 --- a/guppylang-internals/src/guppylang_internals/definition/declaration.py +++ b/guppylang-internals/src/guppylang_internals/definition/declaration.py @@ -6,7 +6,13 @@ from hugr.build import function as hf from hugr.build.dfg import DefinitionBuilder, OpVar -from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc, with_type +from guppylang_internals.ast_util import ( + AstNode, + ImportMap, + has_empty_body, + with_loc, + with_type, +) from guppylang_internals.checker.core import Context, Globals from guppylang_internals.checker.expr_checker import check_call, synthesize_call from guppylang_internals.checker.func_checker import check_signature @@ -19,6 +25,7 @@ from guppylang_internals.definition.function import ( PyFunc, compile_call, + generate_stub_from_def, load_with_args, parse_py_func, ) @@ -69,11 +76,11 @@ class RawFunctionDecl(ParsableDef): unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True) - hugr_name: InitVar[str | None] = field(default=None, kw_only=True) - _user_set_hugr_name: str | None = field(default=None, init=False) + link_name: InitVar[str | None] = field(default=None, kw_only=True) + _user_set_link_name: str | None = field(default=None, init=False) - def __post_init__(self, hugr_name: str | None) -> None: - object.__setattr__(self, "_user_set_hugr_name", hugr_name) + def __post_init__(self, link_name: str | None) -> None: + object.__setattr__(self, "_user_set_link_name", link_name) def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl": """Parses and checks the user-provided signature of the function.""" @@ -81,13 +88,13 @@ def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl": ty = check_signature( func_ast, globals, self.id, unitary_flags=self.unitary_flags ) - hugr_name = f"{self.python_func.__module__}.{self.python_func.__qualname__}" - if self._user_set_hugr_name is not None: - hugr_name = self._user_set_hugr_name + link_name = f"{self.python_func.__module__}.{self.python_func.__qualname__}" + if self._user_set_link_name is not None: + link_name = self._user_set_link_name elif (parent_ty_id := DEF_STORE.impl_parents.get(self.id)) is not None: parent = ENGINE.get_parsed(parent_ty_id) if isinstance(parent, ParsedStructDef): - hugr_name = f"{parent.hugr_name_prefix}.{self.python_func.__name__}" + link_name = f"{parent.link_name_prefix}.{self.python_func.__name__}" if not has_empty_body(func_ast): raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name)) @@ -100,9 +107,13 @@ def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl": func_ast, ty, docstring, - hugr_name, + link_name, + module=self.python_func.__module__, ) + def generate_guppy_declare_decorator(self, import_map: ImportMap) -> ast.expr: + raise NotImplementedError("Must be implemented by a subclass!") + @dataclass(frozen=True) class CheckedFunctionDecl(CompilableDef, CallableDef): @@ -113,7 +124,8 @@ class CheckedFunctionDecl(CompilableDef, CallableDef): defined_at: ast.FunctionDef docstring: str | None - hugr_name: str + link_name: str + module: str | None = field(default=None, kw_only=True) def check_call( self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context @@ -142,17 +154,25 @@ def compile_outer( ) module: hf.Module = module - node = module.declare_function(self.hugr_name, self.ty.to_hugr_poly(ctx)) + node = module.declare_function(self.link_name, self.ty.to_hugr_poly(ctx)) return CompiledFunctionDecl( self.id, self.name, self.defined_at, self.ty, self.docstring, - self.hugr_name, + self.link_name, node, + module=self.module, ) + def stub(self) -> ast.FunctionDef: + """Generates a stub function declaration with an empty body.""" + raw_def = DEF_STORE.raw_defs[self.id] + assert isinstance(raw_def, RawFunctionDecl) + + return generate_stub_from_def(raw_def, self.defined_at) + @dataclass(frozen=True) class CompiledFunctionDecl( diff --git a/guppylang-internals/src/guppylang_internals/definition/function.py b/guppylang-internals/src/guppylang_internals/definition/function.py index 8ee3c4424..dd05f4197 100644 --- a/guppylang-internals/src/guppylang_internals/definition/function.py +++ b/guppylang-internals/src/guppylang_internals/definition/function.py @@ -1,6 +1,7 @@ import ast import inspect from collections.abc import Callable, Sequence +from copy import deepcopy from dataclasses import InitVar, dataclass, field from typing import TYPE_CHECKING, Any @@ -12,6 +13,7 @@ from guppylang_internals.ast_util import ( AstNode, + ImportMap, annotate_location, parse_source, with_loc, @@ -48,15 +50,17 @@ CompiledHugrNodeDef, ) from guppylang_internals.engine import DEF_STORE, ENGINE -from guppylang_internals.error import GuppyError +from guppylang_internals.error import GuppyError, InternalGuppyError from guppylang_internals.nodes import GlobalCall from guppylang_internals.span import SourceMap +from guppylang_internals.tys.parsing import _parse_delayed_annotation from guppylang_internals.tys.subst import Inst, Subst from guppylang_internals.tys.ty import FunctionType, Type, UnitaryFlags, type_to_row if TYPE_CHECKING: from hugr.tys import Visibility + from guppylang_internals.definition.declaration import RawFunctionDecl from guppylang_internals.tys.param import Parameter PyFunc = Callable[..., Any] @@ -85,11 +89,11 @@ class RawFunctionDef(ParsableDef): metadata: GuppyMetadata | None = field(default=None, kw_only=True) - hugr_name: InitVar[str | None] = field(default=None, kw_only=True) - _user_set_hugr_name: str | None = field(default=None, init=False) + link_name: InitVar[str | None] = field(default=None, kw_only=True) + _user_set_link_name: str | None = field(default=None, init=False) - def __post_init__(self, hugr_name: str | None) -> None: - object.__setattr__(self, "_user_set_hugr_name", hugr_name) + def __post_init__(self, link_name: str | None) -> None: + object.__setattr__(self, "_user_set_link_name", link_name) def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" @@ -98,13 +102,13 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": func_ast, globals, self.id, unitary_flags=self.unitary_flags ) - hugr_name = f"{self.python_func.__module__}.{self.python_func.__qualname__}" - if self._user_set_hugr_name is not None: - hugr_name = self._user_set_hugr_name + link_name = f"{self.python_func.__module__}.{self.python_func.__qualname__}" + if self._user_set_link_name is not None: + link_name = self._user_set_link_name elif (parent_ty_id := DEF_STORE.impl_parents.get(self.id)) is not None: parent = ENGINE.get_parsed(parent_ty_id) if isinstance(parent, ParsedStructDef): - hugr_name = f"{parent.hugr_name_prefix}.{self.python_func.__name__}" + link_name = f"{parent.link_name_prefix}.{self.python_func.__name__}" return ParsedFunctionDef( self.id, @@ -112,10 +116,14 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef": func_ast, ty, docstring, - hugr_name, + link_name, + module=self.python_func.__module__, metadata=self.metadata, ) + def generate_guppy_declare_decorator(self, import_map: ImportMap) -> ast.expr: + raise NotImplementedError("Must be implemented by a subclass!") + @dataclass(frozen=True) class ParsedFunctionDef(CheckableDef, CallableDef): @@ -130,13 +138,14 @@ class ParsedFunctionDef(CheckableDef, CallableDef): defined_at: The AST node where the function was defined. ty: The type of the function. docstring: The docstring of the function. - hugr_name: The name that the Hugr node for this function will receive. + link_name: The name that the Hugr node for this function will receive. """ defined_at: ast.FunctionDef ty: FunctionType docstring: str | None - hugr_name: str + link_name: str + module: str | None = field(default=None, kw_only=True) description: str = field(default="function", init=False) @@ -152,8 +161,9 @@ def check(self, globals: Globals) -> "CheckedFunctionDef": self.defined_at, self.ty, self.docstring, - self.hugr_name, + self.link_name, cfg, + module=self.module, metadata=self.metadata, ) @@ -175,6 +185,13 @@ def synthesize_call( node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst)) return with_type(ty, node), ty + def stub(self) -> ast.FunctionDef: + """Generates a stub function declaration with an empty body.""" + raw_def = DEF_STORE.raw_defs[self.id] + assert isinstance(raw_def, RawFunctionDef) + + return generate_stub_from_def(raw_def, self.defined_at) + @dataclass(frozen=True) class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef): @@ -189,7 +206,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef): defined_at: The AST node where the function was defined. ty: The type of the function. docstring: The docstring of the function. - hugr_name: The name that the Hugr node for this function will receive. + link_name: The name that the Hugr node for this function will receive. cfg: The type- and linearity-checked CFG for the function body. """ @@ -217,7 +234,7 @@ def monomorphize( hugr_ty = mono_ty.to_hugr_poly(ctx) visibility: Visibility = "Public" if self.id in ctx.exported_defs else "Private" func_def = module.module_root_builder().define_function( - self.hugr_name, + self.link_name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params, @@ -235,9 +252,10 @@ def monomorphize( mono_args, mono_ty, self.docstring, - self.hugr_name, + self.link_name, self.cfg, func_def, + module=self.module, metadata=self.metadata, ) @@ -255,7 +273,7 @@ class CompiledFunctionDef( mono_args: Partial monomorphization of the generic type parameters. ty: The type of the function after partial monomorphization. docstring: The docstring of the function. - hugr_name: The name of the Hugr node corresponding to this function. + link_name: The name of the Hugr node corresponding to this function. cfg: The type- and linearity-checked CFG for the function body. func_def: The Hugr function definition. """ @@ -334,3 +352,50 @@ def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str | if not isinstance(func_ast, ast.FunctionDef): raise GuppyError(ExpectedError(func_ast, "a function definition")) return parse_function_with_docstring(func_ast) + + +def _mark_names_used(expr: ast.expr, import_map: ImportMap) -> None: + """Marks all names used in the given expression as used in the import map.""" + for node in ast.walk(expr): + if isinstance(node, ast.Name): + import_map.use(node.id) + elif isinstance(node, ast.Constant) and isinstance(node.value, str): + parsed = _parse_delayed_annotation(node.value, node) + _mark_names_used(parsed, import_map) + else: + continue + + +def generate_stub_from_def( + raw_def: "RawFunctionDef | RawFunctionDecl", source_def: ast.FunctionDef +) -> ast.FunctionDef: + if not hasattr(source_def, "file"): + # Should be set in `ast_util.annotate_location(...)`. + raise InternalGuppyError("Source file not set for source def.") + import_map = DEF_STORE.sources.imports[source_def.file] + + func_def = deepcopy(source_def) + for arg in func_def.args.args: + if arg.annotation is not None: + _mark_names_used(arg.annotation, import_map) + if func_def.returns is not None: + _mark_names_used(func_def.returns, import_map) + func_def.body = [ + *( + [ast.Expr(ast.Constant(raw_def.python_func.__doc__))] + if raw_def.python_func.__doc__ + else [] + ), + ast.Expr(ast.Constant(...)), + ] + # TODO register all used imports from function args and return type + # (and type params?) + func_def.decorator_list = [raw_def.generate_guppy_declare_decorator(import_map)] + + # We cannot know these values + func_def.lineno = -1 + func_def.col_offset = -1 + func_def.end_lineno = -1 + func_def.end_col_offset = -1 + + return func_def diff --git a/guppylang-internals/src/guppylang_internals/definition/struct.py b/guppylang-internals/src/guppylang_internals/definition/struct.py index 49661465a..e1d46de7a 100644 --- a/guppylang-internals/src/guppylang_internals/definition/struct.py +++ b/guppylang-internals/src/guppylang_internals/definition/struct.py @@ -116,11 +116,11 @@ class RawStructDef(TypeDef, ParsableDef): python_class: type params: None = field(default=None, init=False) # Params not known yet - hugr_name: InitVar[str | None] = field(default=None, kw_only=True) - _user_set_hugr_name: str | None = field(default=None, init=False) + link_name: InitVar[str | None] = field(default=None, kw_only=True) + _user_set_link_name: str | None = field(default=None, init=False) - def __post_init__(self, hugr_name: str | None) -> None: - object.__setattr__(self, "_user_set_hugr_name", hugr_name) + def __post_init__(self, link_name: str | None) -> None: + object.__setattr__(self, "_user_set_link_name", link_name) def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef": """Parses the raw class object into an AST and checks that it is well-formed.""" @@ -202,13 +202,13 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef": x = overridden.pop() raise GuppyError(DuplicateFieldError(used_func_names[x], self.name, x)) - hugr_name_prefix = ( - self._user_set_hugr_name + link_name_prefix = ( + self._user_set_link_name or f"{self.python_class.__module__}.{self.python_class.__qualname__}" ) return ParsedStructDef( - self.id, self.name, cls_def, params, fields, hugr_name_prefix + self.id, self.name, cls_def, params, fields, link_name_prefix ) def check_instantiate( @@ -224,7 +224,7 @@ class ParsedStructDef(TypeDef, CheckableDef): defined_at: ast.ClassDef params: Sequence[Parameter] fields: Sequence[UncheckedStructField] - hugr_name_prefix: str + link_name_prefix: str def check(self, globals: Globals) -> "CheckedStructDef": """Checks that all struct fields have valid types.""" diff --git a/guppylang-internals/src/guppylang_internals/span.py b/guppylang-internals/src/guppylang_internals/span.py index b35366d9d..9fb68fcd3 100644 --- a/guppylang-internals/src/guppylang_internals/span.py +++ b/guppylang-internals/src/guppylang_internals/span.py @@ -3,9 +3,10 @@ import ast import linecache from dataclasses import dataclass +from pathlib import Path from typing import TypeAlias -from guppylang_internals.ast_util import get_file, get_line_offset +from guppylang_internals.ast_util import ImportMap, get_file, get_line_offset from guppylang_internals.error import InternalGuppyError from guppylang_internals.ipython_inspect import normalize_ipython_dummy_files @@ -137,9 +138,11 @@ class SourceMap: """ sources: dict[str, SourceLines] + imports: dict[str, ImportMap] def __init__(self) -> None: self.sources = {} + self.imports = {} def add_file(self, file: str, content: str | None = None) -> None: """Registers a new source file.""" @@ -148,6 +151,30 @@ def add_file(self, file: str, content: str | None = None) -> None: else: self.sources[file] = content.splitlines(keepends=False) + with Path(file).open() as f: + file_source = f.read() + source_file_ast = ast.parse(file_source, file) + imports = ImportMap() + + def process_stmt(stmt: ast.stmt) -> None: + match stmt: + case ast.Import(names) | ast.ImportFrom(_, names, _): + for alias in names: + if alias.asname is not None: + imports.register_import(alias.asname, stmt, alias) + else: + imports.register_import(alias.name, stmt, alias) + case ast.If(ast.Name("TYPE_CHECKING"), body, _): + for type_checking_stmt in body: + process_stmt(type_checking_stmt) + case _: + pass + + for top_level_stmt in source_file_ast.body: + process_stmt(top_level_stmt) + + self.imports[file] = imports + def span_lines(self, span: Span, prefix_lines: int = 0) -> list[str]: return self.sources[span.file][ span.start.line - prefix_lines - 1 : span.end.line diff --git a/guppylang/src/guppylang/decorator.py b/guppylang/src/guppylang/decorator.py index 9ce3f975b..8048db475 100644 --- a/guppylang/src/guppylang/decorator.py +++ b/guppylang/src/guppylang/decorator.py @@ -5,7 +5,7 @@ from types import FrameType from typing import Any, NamedTuple, ParamSpec, TypedDict, TypeVar, cast, overload -from guppylang_internals.ast_util import annotate_location +from guppylang_internals.ast_util import ImportMap, annotate_location from guppylang_internals.compiler.core import ( CompilerContext, ) @@ -95,7 +95,7 @@ class GuppyKwargs(TypedDict, total=False): dagger: bool power: bool max_qubits: int - hugr_name: str + link_name: str class GuppyStructKwargs(TypedDict, total=False): @@ -103,7 +103,65 @@ class GuppyStructKwargs(TypedDict, total=False): `@guppy.struct` decorator. """ - hugr_name: str + link_name: str + + +def _generate_guppy_declare_decorator( + parsed: "ParsedGuppyKwargs", import_map: ImportMap +) -> ast.expr: + """Generates an AST expression that reconstructs this function definition as a + call to the `@guppy.declare` decorator with the same parameters as the original + definition. + """ + kwargs = [ + ast.keyword(keyword, ast.Constant(value)) # type: ignore[arg-type] + for keyword, value in unparse_kwargs(parsed).items() + ] + + # Workaround to get the name of the exported decorator symbol + decorator_name_str = f"{guppy=}".split("=")[0] + import_map.use(decorator_name_str) + decorator = ast.Attribute( + ast.Name(id=decorator_name_str, ctx=ast.Load()), + attr=_Guppy.declare.__name__, + ctx=ast.Load(), + ) + if len(kwargs) == 0: + return decorator + + return ast.Call(func=decorator, args=[], keywords=kwargs) + + +class _RawReconstructableFunctionDef(RawFunctionDef): + def generate_guppy_declare_decorator(self, import_map: ImportMap) -> ast.expr: + """Generates an AST expression that reconstructs this function definition as a + call to the `@guppy.declare` decorator with the same parameters as the original + definition. + """ + return _generate_guppy_declare_decorator( + ParsedGuppyKwargs( + flags=self.unitary_flags, + metadata=self.metadata or GuppyMetadata(), + link_name=self._user_set_link_name, + ), + import_map, + ) + + +class _RawReconstructableFunctionDecl(RawFunctionDecl): + def generate_guppy_declare_decorator(self, import_map: ImportMap) -> ast.expr: + """Generates an AST expression that reconstructs this function declaration as a + call to the `@guppy.declare` decorator with the same parameters as the original + declaration. + """ + return _generate_guppy_declare_decorator( + ParsedGuppyKwargs( + flags=self.unitary_flags, + metadata=GuppyMetadata(), + link_name=self._user_set_link_name, + ), + import_map, + ) class _Guppy: @@ -127,14 +185,14 @@ def dec( f: Callable[P, T], kwargs: GuppyKwargs ) -> GuppyFunctionDefinition[P, T]: parsed = _parse_kwargs(kwargs) - defn = RawFunctionDef( + defn = _RawReconstructableFunctionDef( DefId.fresh(), f.__name__, None, f, unitary_flags=parsed.flags, metadata=parsed.metadata, - hugr_name=parsed.hugr_name, + link_name=parsed.link_name, ) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) @@ -218,7 +276,7 @@ def add_fields(self: "MyStruct") -> int: return self.field2 + self.field2 # Add optional parameters - @guppy.struct(hugr_name="my_struct") + @guppy.struct(link_name="my_struct") class MyStruct2: field1: int field2: int @@ -230,7 +288,7 @@ def dec(cls: builtins.type[T], kwargs: GuppyStructKwargs) -> GuppyDefinition: cls.__name__, None, cls, - hugr_name=kwargs.pop("hugr_name", None), + link_name=kwargs.pop("link_name", None), ) frame = get_calling_frame() DEF_STORE.register_def(defn, frame) @@ -332,13 +390,13 @@ def dec( f: Callable[P, T], kwargs: GuppyKwargs ) -> GuppyFunctionDefinition[P, T]: parsed = _parse_kwargs(kwargs) - defn = RawFunctionDecl( + defn = _RawReconstructableFunctionDecl( DefId.fresh(), f.__name__, None, f, unitary_flags=parsed.flags, - hugr_name=parsed.hugr_name, + link_name=parsed.link_name, ) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) @@ -689,7 +747,7 @@ def _with_optional_kwargs( class ParsedGuppyKwargs(NamedTuple): flags: UnitaryFlags metadata: GuppyMetadata - hugr_name: str | None + link_name: str | None @hide_trace @@ -710,7 +768,7 @@ def _parse_kwargs(kwargs: GuppyKwargs) -> ParsedGuppyKwargs: metadata = GuppyMetadata() metadata.max_qubits.value = kwargs.pop("max_qubits", None) - hugr_name = kwargs.pop("hugr_name", None) + link_name = kwargs.pop("link_name", None) if remaining := next(iter(kwargs), None): err = f"Unknown keyword argument: `{remaining}`" @@ -719,8 +777,35 @@ def _parse_kwargs(kwargs: GuppyKwargs) -> ParsedGuppyKwargs: return ParsedGuppyKwargs( flags=flags, metadata=metadata, - hugr_name=hugr_name, + link_name=link_name, ) +def unparse_kwargs(parsed: ParsedGuppyKwargs) -> GuppyKwargs: + """Unparses the given `ParsedGuppyKwargs` back into a form that can be passed to the + `@guppy` decorator. + """ + kwargs = GuppyKwargs() + match parsed.flags: + case UnitaryFlags.NoFlags: + pass + case UnitaryFlags.Unitary: + kwargs["unitary"] = True + case value: + if value & UnitaryFlags.Control: # type: ignore[operator] + kwargs["control"] = True + if value & UnitaryFlags.Dagger: # type: ignore[operator] + kwargs["dagger"] = True + if value & UnitaryFlags.Power: # type: ignore[operator] + kwargs["power"] = True + + if parsed.metadata.max_qubits.value is not None: + kwargs["max_qubits"] = parsed.metadata.max_qubits.value + + if parsed.link_name is not None: + kwargs["link_name"] = parsed.link_name + + return kwargs + + guppy = cast("_Guppy", _DummyGuppy()) if sphinx_running() else _Guppy() diff --git a/guppylang/src/guppylang/defs.py b/guppylang/src/guppylang/defs.py index 6f47446a9..e87da27f4 100644 --- a/guppylang/src/guppylang/defs.py +++ b/guppylang/src/guppylang/defs.py @@ -4,17 +4,20 @@ with the compiler-internal definition objects in the `definitions` module. """ +import ast +import importlib from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Generic, ParamSpec, TypeVar, cast +from typing import Any, ClassVar, Generic, ParamSpec, TypeVar, cast import guppylang_internals from guppylang_internals.definition.common import DefId -from guppylang_internals.definition.function import RawFunctionDef +from guppylang_internals.definition.declaration import CheckedFunctionDecl +from guppylang_internals.definition.function import CheckedFunctionDef, RawFunctionDef from guppylang_internals.definition.value import CompiledCallableDef from guppylang_internals.diagnostic import Error, Note -from guppylang_internals.engine import ENGINE -from guppylang_internals.error import GuppyError, pretty_errors +from guppylang_internals.engine import DEF_STORE, ENGINE +from guppylang_internals.error import GuppyError, InternalGuppyError, pretty_errors from guppylang_internals.span import Span, to_span from guppylang_internals.tracing.object import TracingDefMixin from guppylang_internals.tracing.util import hide_trace @@ -28,9 +31,6 @@ from guppylang.emulator import EmulatorBuilder, EmulatorInstance from guppylang.emulator.exceptions import EmulatorBuildError -if TYPE_CHECKING: - import ast - __all__ = ( "GuppyDefinition", "GuppyFunctionDefinition", @@ -230,6 +230,60 @@ def check(self) -> None: """Type-check all definitions Guppy definition.""" ENGINE.check(self.members) + def stubs(self) -> dict[str, str]: + stub_asts_by_module: dict[str, list[ast.stmt]] = {} + for member in self.members: + checked_def = ENGINE.get_checked(member) + match checked_def: + case CheckedFunctionDef(): + if checked_def.module is None: + raise InternalGuppyError( + "Checked definition has no associated module, cannot " + "generate stub!" + ) + stub_asts_by_module.setdefault(checked_def.module, []).append( + checked_def.stub() + ) + case CheckedFunctionDecl(): + if checked_def.module is None: + raise InternalGuppyError( + "Checked definition has no associated module, cannot " + "generate stub!" + ) + stub_asts_by_module.setdefault(checked_def.module, []).append( + checked_def.stub() + ) + case _: + raise NotImplementedError( + f"Cannot yet generate stubs for definitions of type " + f"{type(checked_def)}!" + ) + + module_stubs: dict[str, str] = {} + for module_name, stub_asts in stub_asts_by_module.items(): + module = importlib.import_module(module_name) + imports = ( + DEF_STORE.sources.imports[module.__file__].dump_ast() + if module.__file__ is not None + else [] + ) + + module_ast = ast.Module( + [ + *( + [ast.Expr(ast.Constant(module.__doc__))] + if module.__doc__ + else [] + ), + *imports, + *stub_asts, + ], + type_ignores=[], + ) + module_stubs[module_name] = ast.unparse(module_ast) + + return module_stubs + @dataclass(frozen=True) class GuppyTypeVarDefinition(GuppyDefinition): diff --git a/ruff.toml b/ruff.toml index 45e732843..0f92f255f 100644 --- a/ruff.toml +++ b/ruff.toml @@ -7,6 +7,7 @@ extend-exclude = [ "tests/error", "tests/integration/test_poly_py312.py", "tests/integration/test_comptime_expr_py312.py", + "tests/integration/stub/*.pyi", "*.ipynb", ] diff --git a/tests/integration/stub/__init__.py b/tests/integration/stub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/stub/basic.py b/tests/integration/stub/basic.py new file mode 100644 index 000000000..dd0dd25d1 --- /dev/null +++ b/tests/integration/stub/basic.py @@ -0,0 +1,33 @@ +"""Basic tests for generating Guppy stubs.""" + +from guppylang import guppy + + +@guppy +def lib_func(x: int) -> int: + return x + + +@guppy(link_name="my.custom.link.name") +def lib_custom_link_name(x: int) -> int: + return x + + +@guppy(unitary=True) +def lib_unitary(x: int) -> int: + return x + + +@guppy(control=True, dagger=True) +def lib_multiple_modifiers(x: int) -> int: + return x + + +@guppy(unitary=True, link_name="my.custom.link.name") +def lib_multiple_kwargs(x: int) -> int: + return x + + +@guppy(max_qubits=2) +def lib_max_qubits(x: int) -> int: + return x diff --git a/tests/integration/stub/basic.pyi b/tests/integration/stub/basic.pyi new file mode 100644 index 000000000..ddb17731f --- /dev/null +++ b/tests/integration/stub/basic.pyi @@ -0,0 +1,26 @@ +"""Basic tests for generating Guppy stubs.""" +from guppylang import guppy + +@guppy.declare +def lib_func(x: int) -> int: + ... + +@guppy.declare(link_name='my.custom.link.name') +def lib_custom_link_name(x: int) -> int: + ... + +@guppy.declare(unitary=True) +def lib_unitary(x: int) -> int: + ... + +@guppy.declare(control=True, dagger=True) +def lib_multiple_modifiers(x: int) -> int: + ... + +@guppy.declare(unitary=True, link_name='my.custom.link.name') +def lib_multiple_kwargs(x: int) -> int: + ... + +@guppy.declare(max_qubits=2) +def lib_max_qubits(x: int) -> int: + ... \ No newline at end of file diff --git a/tests/integration/stub/declaration.py b/tests/integration/stub/declaration.py new file mode 100644 index 000000000..13d829066 --- /dev/null +++ b/tests/integration/stub/declaration.py @@ -0,0 +1,11 @@ +"""Tests that Guppy stubs can be generated for function declarations.""" + +from guppylang import guppy + + +@guppy.declare +def lib_decl(x: int) -> int: ... + + +@guppy.declare(link_name="my.custom.link.name") +def lib_decl_custom_link_name(x: int) -> int: ... diff --git a/tests/integration/stub/declaration.pyi b/tests/integration/stub/declaration.pyi new file mode 100644 index 000000000..49a32f98e --- /dev/null +++ b/tests/integration/stub/declaration.pyi @@ -0,0 +1,10 @@ +"""Tests that Guppy stubs can be generated for function declarations.""" +from guppylang import guppy + +@guppy.declare +def lib_decl(x: int) -> int: + ... + +@guppy.declare(link_name='my.custom.link.name') +def lib_decl_custom_link_name(x: int) -> int: + ... \ No newline at end of file diff --git a/tests/integration/stub/docstring.py b/tests/integration/stub/docstring.py new file mode 100644 index 000000000..97ebd7e5a --- /dev/null +++ b/tests/integration/stub/docstring.py @@ -0,0 +1,17 @@ +"""Tests for inclusion of docstrings in the stub files. This module-level docstring +should be included.""" + +from guppylang import guppy + + +@guppy +def lib_docstring(x: int) -> int: + """A docstring for this wonderful function, that is included in the stubs.""" + return x + + +@guppy +def lib_docstring_multiline(x: int) -> int: + """A much longer docstring for this other wonderful function, which we should + include in the stubs regardless of what is customary and what is not.""" + return x diff --git a/tests/integration/stub/docstring.pyi b/tests/integration/stub/docstring.pyi new file mode 100644 index 000000000..fff922988 --- /dev/null +++ b/tests/integration/stub/docstring.pyi @@ -0,0 +1,14 @@ +"""Tests for inclusion of docstrings in the stub files. This module-level docstring +should be included.""" +from guppylang import guppy + +@guppy.declare +def lib_docstring(x: int) -> int: + """A docstring for this wonderful function, that is included in the stubs.""" + ... + +@guppy.declare +def lib_docstring_multiline(x: int) -> int: + """A much longer docstring for this other wonderful function, which we should + include in the stubs regardless of what is customary and what is not.""" + ... \ No newline at end of file diff --git a/tests/integration/stub/unused_imports.py b/tests/integration/stub/unused_imports.py new file mode 100644 index 000000000..1ec8533c2 --- /dev/null +++ b/tests/integration/stub/unused_imports.py @@ -0,0 +1,44 @@ +"""Whether imports are correctly marked as used / unused by stubs.""" + +from typing import TYPE_CHECKING + +from guppylang import guppy, comptime +from ast import Module # noqa: F401 +from guppylang.std.quantum import qubit, discard +from guppylang.std.array import array, array_swap, frozenarray # noqa: F401 +from guppylang.std.angles import angle + +if TYPE_CHECKING: + from guppylang.std.num import nat + + +@guppy +def lib_func(x: int) -> int: + return x + + +@guppy +def lib_func_using_import_in_body(x: int) -> int: + discard(qubit()) + + return x + + +@guppy +def lib_func_using_import_in_args_plain(x: angle) -> None: + pass + + +@guppy +def lib_func_using_import_in_args_string(x: "nat") -> None: + pass + + +@guppy +def lib_func_using_import_in_return_plain() -> array[int, 3]: + return array(1, 2, 3) + + +@guppy +def lib_func_using_import_in_return_string() -> "frozenarray[int, 3]": + return comptime(list(range(3))) diff --git a/tests/integration/stub/unused_imports.pyi b/tests/integration/stub/unused_imports.pyi new file mode 100644 index 000000000..2f6c7d1ee --- /dev/null +++ b/tests/integration/stub/unused_imports.pyi @@ -0,0 +1,29 @@ +"""Whether imports are correctly marked as used / unused by stubs.""" +from guppylang import guppy +from guppylang.std.array import array, frozenarray +from guppylang.std.angles import angle +from guppylang.std.num import nat + +@guppy.declare +def lib_func(x: int) -> int: + ... + +@guppy.declare +def lib_func_using_import_in_body(x: int) -> int: + ... + +@guppy.declare +def lib_func_using_import_in_args_plain(x: angle) -> None: + ... + +@guppy.declare +def lib_func_using_import_in_args_string(x: 'nat') -> None: + ... + +@guppy.declare +def lib_func_using_import_in_return_plain() -> array[int, 3]: + ... + +@guppy.declare +def lib_func_using_import_in_return_string() -> 'frozenarray[int, 3]': + ... \ No newline at end of file diff --git a/tests/integration/stub/unused_imports_expressions.py b/tests/integration/stub/unused_imports_expressions.py new file mode 100644 index 000000000..9607724a7 --- /dev/null +++ b/tests/integration/stub/unused_imports_expressions.py @@ -0,0 +1,23 @@ +"""More complicated expressions of imports that should not affect retaining them in +stubs.""" + +from guppylang import guppy, comptime +from guppylang.std.quantum import qubit, discard +from guppylang.std.num import nat +from guppylang.std.lang import owned + + +# Currently ignored, since we do not support type unions yet +# @guppy +def _lib_func_or_none(x: int | None) -> None: + pass + + +@guppy +def lib_func_comptime_arg(x: nat @ comptime) -> None: + pass + + +@guppy +def lib_func_owned_arg(x: qubit @ owned) -> None: + discard(x) diff --git a/tests/integration/stub/unused_imports_expressions.pyi b/tests/integration/stub/unused_imports_expressions.pyi new file mode 100644 index 000000000..f4cb547e6 --- /dev/null +++ b/tests/integration/stub/unused_imports_expressions.pyi @@ -0,0 +1,14 @@ +"""More complicated expressions of imports that should not affect retaining them in +stubs.""" +from guppylang import guppy, comptime +from guppylang.std.quantum import qubit +from guppylang.std.num import nat +from guppylang.std.lang import owned + +@guppy.declare +def lib_func_comptime_arg(x: nat @ comptime) -> None: + ... + +@guppy.declare +def lib_func_owned_arg(x: qubit @ owned) -> None: + ... \ No newline at end of file diff --git a/tests/integration/test_hugr_name.py b/tests/integration/test_link_name.py similarity index 76% rename from tests/integration/test_hugr_name.py rename to tests/integration/test_link_name.py index 872e83874..6a016e992 100644 --- a/tests/integration/test_hugr_name.py +++ b/tests/integration/test_link_name.py @@ -35,22 +35,22 @@ def _func_names_excluding_main(package: Package, qualifier: str) -> set[str]: return func_names -def test_func_hugr_name_annotated(): - """Asserts that annotated function `hugr_name`s are passed to the HUGR nodes.""" +def test_func_link_name_annotated(): + """Asserts that annotated function `link_name`s are passed to the HUGR nodes.""" - @guppy(hugr_name="some.qualified.name") + @guppy(link_name="some.qualified.name") def main_def() -> None: pass - @guppy.declare(hugr_name="some.other.qualified.name") + @guppy.declare(link_name="some.other.qualified.name") def main_dec() -> None: ... assert _func_names(main_def.compile()) == {"some.qualified.name"} assert _func_names(main_dec.compile()) == {"some.other.qualified.name"} -def test_func_hugr_name_inferred(qualifier): - """Asserts that inferred function `hugr_name`s are passed to the HUGR nodes.""" +def test_func_link_name_inferred(qualifier): + """Asserts that inferred function `link_name`s are passed to the HUGR nodes.""" @guppy def crazy_def() -> None: @@ -63,16 +63,16 @@ def crazy_dec() -> None: ... assert _func_names(crazy_dec.compile()) == {f"{qualifier}..crazy_dec"} -def test_struct_member_hugr_name_annotated(qualifier): - """Asserts that inferred function `hugr_name`s are passed to the HUGR nodes.""" +def test_struct_member_link_name_annotated(qualifier): + """Asserts that inferred function `link_name`s are passed to the HUGR nodes.""" @guppy.struct class MySuperbStruct: - @guppy(hugr_name="totally_qualified_override_name") + @guppy(link_name="totally_qualified_override_name") def some_name_that_is_crazy(self) -> None: pass - @guppy.declare(hugr_name="superbly_qualified_override_name") + @guppy.declare(link_name="superbly_qualified_override_name") def some_other_name_that_is_crazy(self) -> None: ... @guppy @@ -88,8 +88,8 @@ def main() -> None: } -def test_struct_member_hugr_name_inferred(qualifier): - """Asserts that inferred function `hugr_name`s are passed to the HUGR nodes.""" +def test_struct_member_link_name_inferred(qualifier): + """Asserts that inferred function `link_name`s are passed to the HUGR nodes.""" @guppy.struct class MySuperbStruct: @@ -113,17 +113,17 @@ def main() -> None: } -def test_struct_member_hugr_name_supported(qualifier): - """Asserts that function `hugr_name`s of struct members that are derived through - providing a `hugr_name` to the struct are correctly inferred.""" +def test_struct_member_link_name_supported(qualifier): + """Asserts that function `link_name`s of struct members that are derived through + providing a `link_name` to the struct are correctly inferred.""" - @guppy.struct(hugr_name="my.superb.qualifier") + @guppy.struct(link_name="my.superb.qualifier") class MySuperbStruct: @guppy def some_name_that_is_crazy(self) -> None: pass - @guppy(hugr_name="the.override.still.works") + @guppy(link_name="the.override.still.works") def some_other_name_that_is_crazy(self) -> None: pass @@ -172,8 +172,8 @@ def main() -> None: file_level_defn() assert _func_names_excluding_main(main.compile(), qualifier) == { - "tests.integration.test_hugr_name.file_level_defn", - "tests.integration.test_hugr_name.file_level_decl", - "tests.integration.test_hugr_name.FileLevelStruct.crazy_name_defn", - "tests.integration.test_hugr_name.FileLevelStruct.crazy_name_decl", + "tests.integration.test_link_name.file_level_defn", + "tests.integration.test_link_name.file_level_decl", + "tests.integration.test_link_name.FileLevelStruct.crazy_name_defn", + "tests.integration.test_link_name.FileLevelStruct.crazy_name_decl", } diff --git a/tests/integration/test_stub.py b/tests/integration/test_stub.py new file mode 100644 index 000000000..5631a335b --- /dev/null +++ b/tests/integration/test_stub.py @@ -0,0 +1,72 @@ +import importlib +import importlib.util +import importlib.machinery +import pathlib +import sys +import types +from pathlib import Path + +import pytest + +from guppylang import guppy + + +def import_from_path( + module_name: str, file_path: Path +) -> tuple[types.ModuleType, importlib.machinery.ModuleSpec]: + loader = importlib.machinery.SourceFileLoader(module_name, str(file_path)) + spec = importlib.util.spec_from_file_location( + module_name, str(file_path), loader=loader + ) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + return module, spec + + +def _run_stub_test(file, snapshot): + file = pathlib.Path(file) + + module_name = f"tests.integration.stub.{file.stem}" + module = importlib.import_module(module_name) + # Collect top level functions defined in the module + library = guppy.library( + *[item for name, item in module.__dict__.items() if name.startswith("lib_")], + ) + library.check() + stubs = library.stubs() + + assert module_name in stubs, f"Expected stubs to be generated for {module_name}" + assert len(stubs) == 1, ( + "Expected exactly one stub module to be generated, but got: " + + ", ".join(stubs.keys()) + ) + + stub_file = file.with_suffix(".pyi") + snapshot.snapshot_dir = str(file.parent) + snapshot.assert_match(stubs[module_name], stub_file.name) + + # Test whether stubs can be imported, e.g. to test whether all required names are + # defined in import statements. + module, spec = import_from_path(file.stem, stub_file.resolve()) + try: + spec.loader.exec_module(module) + except Exception as e: # noqa: BLE001 + pytest.fail( + f"Type stubs were generated, but are bad! Exception during import: {e!s}" + ) + + +path = pathlib.Path(__file__).parent.resolve() / "stub" +files = [ + x + for x in path.iterdir() + if x.is_file() and x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_stubs(file, snapshot): + _run_stub_test(file, snapshot)