Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions guppylang-internals/src/guppylang_internals/ast_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import textwrap
from collections.abc import Callable, Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast

Expand Down Expand Up @@ -418,3 +419,55 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
else:
node = ast.parse(source).body[0]
return source, node, line_offset


ImportStmt = ast.Import | ast.ImportFrom


class ImportMap:
"""Maps imported names to their fully qualified module paths."""

imports_by_name: dict[str, tuple[ImportStmt, ast.alias]]
used_names: set[str]

def __init__(self) -> None:
self.imports_by_name = {}
self.used_names = set()

def register_import(self, name: str, stmt: ImportStmt, alias: ast.alias) -> None:
self.imports_by_name[name] = (stmt, alias)

def has(self, name: str) -> bool:
return name in self.imports_by_name

def use(self, name: str) -> None:
"""Marks a name as used, causing the ast generated by `dump_ast` to include an
import statement for it."""
# if not self.has(name):
# raise KeyError(
# f"Name `{name}` is not registered in this import map. "
# f"Perhaps it is not a top level import?"
# )

return self.used_names.add(name)

def dump_ast(self) -> list[ImportStmt]:
"""Returns a list of AST statements representing the imports registered in
this map."""
used_aliases_by_original_stmt = dict[ImportStmt, list[ast.alias]]()
for name, (stmt, alias) in self.imports_by_name.items():
if name not in self.used_names:
continue

if stmt not in used_aliases_by_original_stmt:
used_aliases_by_original_stmt[stmt] = []

used_aliases_by_original_stmt[stmt].append(alias)

stmts = []
for stmt, used_aliases in used_aliases_by_original_stmt.items():
new_stmt = deepcopy(stmt)
new_stmt.names = used_aliases
stmts.append(new_stmt)

return stmts
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def check_nested_func_def(
func_ty,
None,
# Even though global, this function will be private to the built hugr,
# so the hugr name does not really matter.
hugr_name=func_def.name,
# so the link name does not really matter.
link_name=func_def.name,
)
DEF_STORE.register_def(func, None)
ENGINE.parsed[def_id] = func
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from hugr.build import function as hf
from hugr.build.dfg import DefinitionBuilder, OpVar

from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc, with_type
from guppylang_internals.ast_util import (
AstNode,
ImportMap,
has_empty_body,
with_loc,
with_type,
)
from guppylang_internals.checker.core import Context, Globals
from guppylang_internals.checker.expr_checker import check_call, synthesize_call
from guppylang_internals.checker.func_checker import check_signature
Expand All @@ -19,6 +25,7 @@
from guppylang_internals.definition.function import (
PyFunc,
compile_call,
generate_stub_from_def,
load_with_args,
parse_py_func,
)
Expand Down Expand Up @@ -69,25 +76,25 @@ class RawFunctionDecl(ParsableDef):

unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)

hugr_name: InitVar[str | None] = field(default=None, kw_only=True)
_user_set_hugr_name: str | None = field(default=None, init=False)
link_name: InitVar[str | None] = field(default=None, kw_only=True)
_user_set_link_name: str | None = field(default=None, init=False)

def __post_init__(self, hugr_name: str | None) -> None:
object.__setattr__(self, "_user_set_hugr_name", hugr_name)
def __post_init__(self, link_name: str | None) -> None:
object.__setattr__(self, "_user_set_link_name", link_name)

def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
"""Parses and checks the user-provided signature of the function."""
func_ast, docstring = parse_py_func(self.python_func, sources)
ty = check_signature(
func_ast, globals, self.id, unitary_flags=self.unitary_flags
)
hugr_name = f"{self.python_func.__module__}.{self.python_func.__qualname__}"
if self._user_set_hugr_name is not None:
hugr_name = self._user_set_hugr_name
link_name = f"{self.python_func.__module__}.{self.python_func.__qualname__}"
if self._user_set_link_name is not None:
link_name = self._user_set_link_name
elif (parent_ty_id := DEF_STORE.impl_parents.get(self.id)) is not None:
parent = ENGINE.get_parsed(parent_ty_id)
if isinstance(parent, ParsedStructDef):
hugr_name = f"{parent.hugr_name_prefix}.{self.python_func.__name__}"
link_name = f"{parent.link_name_prefix}.{self.python_func.__name__}"

if not has_empty_body(func_ast):
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
Expand All @@ -100,9 +107,13 @@ def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
func_ast,
ty,
docstring,
hugr_name,
link_name,
module=self.python_func.__module__,
)

def generate_guppy_declare_decorator(self, import_map: ImportMap) -> ast.expr:
raise NotImplementedError("Must be implemented by a subclass!")


@dataclass(frozen=True)
class CheckedFunctionDecl(CompilableDef, CallableDef):
Expand All @@ -113,7 +124,8 @@ class CheckedFunctionDecl(CompilableDef, CallableDef):

defined_at: ast.FunctionDef
docstring: str | None
hugr_name: str
link_name: str
module: str | None = field(default=None, kw_only=True)

def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
Expand Down Expand Up @@ -142,17 +154,25 @@ def compile_outer(
)
module: hf.Module = module

node = module.declare_function(self.hugr_name, self.ty.to_hugr_poly(ctx))
node = module.declare_function(self.link_name, self.ty.to_hugr_poly(ctx))
return CompiledFunctionDecl(
self.id,
self.name,
self.defined_at,
self.ty,
self.docstring,
self.hugr_name,
self.link_name,
node,
module=self.module,
)

def stub(self) -> ast.FunctionDef:
"""Generates a stub function declaration with an empty body."""
raw_def = DEF_STORE.raw_defs[self.id]
assert isinstance(raw_def, RawFunctionDecl)

return generate_stub_from_def(raw_def, self.defined_at)


@dataclass(frozen=True)
class CompiledFunctionDecl(
Expand Down
Loading
Loading