Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 0 additions & 60 deletions guppylang-internals/src/guppylang_internals/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
NamedTuple,
TypeAlias,
TypeVar,
cast,
overload,
)

Expand All @@ -25,29 +24,12 @@
Definition,
ParsedDef,
)
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.definition.value import CallableDef
from guppylang_internals.engine import BUILTIN_DEFS, DEF_STORE, ENGINE
from guppylang_internals.error import InternalGuppyError
from guppylang_internals.tys.builtin import (
callable_type_def,
float_type_def,
int_type_def,
nat_type_def,
none_type_def,
tuple_type_def,
)
from guppylang_internals.tys.param import Parameter
from guppylang_internals.tys.ty import (
BoundTypeVar,
ExistentialTypeVar,
FunctionType,
InputFlags,
NoneType,
NumericType,
OpaqueType,
StructType,
TupleType,
Type,
)

Expand Down Expand Up @@ -350,48 +332,6 @@ def builtin_defs() -> dict[str, Definition]:
if isinstance(val, GuppyDefinition)
}

def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None:
"""Looks up an instance function with a given name for a type.

Returns `None` if the name doesn't exist or isn't a function.
"""
type_defn: TypeDef
match ty:
case TypeDef() as type_defn:
pass
case BoundTypeVar() | ExistentialTypeVar():
return None
case NumericType(kind):
match kind:
case NumericType.Kind.Nat:
type_defn = nat_type_def
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
type_defn = float_type_def
case kind:
return assert_never(kind)
case FunctionType():
type_defn = callable_type_def
case OpaqueType() as ty:
type_defn = ty.defn
case StructType() as ty:
type_defn = ty.defn
case TupleType():
type_defn = tuple_type_def
case NoneType():
type_defn = none_type_def
case _:
return assert_never(ty)

type_defn = cast(TypeDef, ENGINE.get_checked(type_defn.id))
if type_defn.id in DEF_STORE.impls and name in DEF_STORE.impls[type_defn.id]:
def_id = DEF_STORE.impls[type_defn.id][name]
defn = ENGINE.get_parsed(def_id)
if isinstance(defn, CallableDef):
return defn
return None

def __contains__(self, item: DefId | str) -> bool:
match item:
case DefId() as def_id:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from guppylang_internals.definition.parameter import ParamDef
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.definition.value import CallableDef, ValueDef
from guppylang_internals.engine import ENGINE
from guppylang_internals.error import (
GuppyError,
GuppyTypeError,
Expand Down Expand Up @@ -343,7 +344,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
TensorCall(func=node.func, args=processed_args, tensor_ty=tensor_ty),
), subst

elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"):
elif callee := ENGINE.get_instance_func(func_ty, "__call__"):
return callee.check_call(node.args, ty, node, self.ctx)
else:
raise GuppyTypeError(NotCallableError(node.func, func_ty))
Expand Down Expand Up @@ -455,7 +456,7 @@ def _check_global(
case ValueDef() as defn:
return with_loc(node, GlobalName(id=name, def_id=defn.id)), defn.ty
# For types, we return their `__new__` constructor
case TypeDef() as defn if constr := self.ctx.globals.get_instance_func(
case TypeDef() as defn if constr := ENGINE.get_instance_func(
defn, "__new__"
):
return with_loc(node, GlobalName(id=name, def_id=constr.id)), constr.ty
Expand Down Expand Up @@ -513,7 +514,7 @@ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
# you loose access to all fields besides `a`).
expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field)
return with_loc(node, expr), field.ty
elif func := self.ctx.globals.get_instance_func(ty, node.attr):
elif func := ENGINE.get_instance_func(ty, node.attr):
name = with_type(
func.ty, with_loc(node, GlobalName(id=func.name, def_id=func.id))
)
Expand Down Expand Up @@ -581,7 +582,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]:

# Check all other unary expressions by calling out to instance dunder methods
op, display_name = unary_table[node.op.__class__]
func = self.ctx.globals.get_instance_func(op_ty, op)
func = ENGINE.get_instance_func(op_ty, op)
if func is None:
raise GuppyTypeError(
UnaryOperatorNotDefinedError(node.operand, op_ty, display_name)
Expand All @@ -602,11 +603,11 @@ def _synthesize_binary(
left_expr, left_ty = self.synthesize(left_expr)
right_expr, right_ty = self.synthesize(right_expr)

if func := self.ctx.globals.get_instance_func(left_ty, lop):
if func := ENGINE.get_instance_func(left_ty, lop):
with suppress(GuppyError):
return func.synthesize_call([left_expr, right_expr], node, self.ctx)

if func := self.ctx.globals.get_instance_func(right_ty, rop):
if func := ENGINE.get_instance_func(right_ty, rop):
with suppress(GuppyError):
return func.synthesize_call([right_expr, left_expr], node, self.ctx)

Expand Down Expand Up @@ -634,7 +635,7 @@ def synthesize_instance_func(
given expected signature.
"""
node, ty = self.synthesize(node)
func = self.ctx.globals.get_instance_func(ty, func_name)
func = ENGINE.get_instance_func(ty, func_name)
if func is None:
err = BadProtocolError(node, ty, description)
if give_reason and exp_sig is not None:
Expand Down Expand Up @@ -770,7 +771,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
node, TensorCall(func=node.func, args=args, tensor_ty=tensor_ty)
), return_ty

elif f := self.ctx.globals.get_instance_func(ty, "__call__"):
elif f := ENGINE.get_instance_func(ty, "__call__"):
return f.synthesize_call(node.args, node, self.ctx)
else:
raise GuppyTypeError(NotCallableError(node.func, ty))
Expand Down Expand Up @@ -915,7 +916,7 @@ def try_coerce_to(
return None
# Ordering on `NumericType.Kind` defines the coercion relation
if act.kind < exp.kind:
f = ctx.globals.get_instance_func(act, f"__{exp.kind.name.lower()}__")
f = ENGINE.get_instance_func(act, f"__{exp.kind.name.lower()}__")
assert f is not None
node, subst = f.check_call([node], exp, node, ctx)
assert len(subst) == 0, "Coercion methods are not generic"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
check_place_assignable,
synthesize_comprehension,
)
from guppylang_internals.engine import ENGINE
from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang_internals.nodes import (
AnyUnpack,
Expand Down Expand Up @@ -326,7 +327,7 @@ def _check_unpackable(
)
raise GuppyError(err)

elif self.ctx.globals.get_instance_func(ty, "__iter__"):
elif ENGINE.get_instance_func(ty, "__iter__"):
size = check_iter_unpack_has_static_size(expr, self.ctx)
# Create a dummy variable and assign the expression to it. This helps us to
# wire it up correctly during Hugr generation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def build_compiled_instance_func(
"""
from guppylang_internals.engine import ENGINE

parsed_func = self.checked_globals.get_instance_func(ty, name)
parsed_func = ENGINE.get_instance_func(ty, name)
if parsed_func is None:
return None
checked_func = ENGINE.get_checked(parsed_func.id)
Expand Down
58 changes: 56 additions & 2 deletions guppylang-internals/src/guppylang_internals/engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from enum import Enum
from types import FrameType
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, assert_never, cast

import hugr.build.function as hf
import hugr.std.collections.array
Expand All @@ -25,6 +25,7 @@
)
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.definition.value import (
CallableDef,
CompiledCallableDef,
CompiledHugrNodeDef,
)
Expand All @@ -46,7 +47,17 @@
string_type_def,
tuple_type_def,
)
from guppylang_internals.tys.ty import FunctionType
from guppylang_internals.tys.ty import (
BoundTypeVar,
ExistentialTypeVar,
FunctionType,
NoneType,
NumericType,
OpaqueType,
StructType,
TupleType,
Type,
)

if TYPE_CHECKING:
from guppylang_internals.compiler.core import MonoDefId
Expand Down Expand Up @@ -308,5 +319,48 @@ def compile(self, id: DefId) -> ModulePointer:
}
return ModulePointer(Package(modules=[graph.hugr], extensions=extensions), 0)

@pretty_errors
def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None:
"""Looks up an instance function with a given name for a type.

Returns `None` if the name doesn't exist or isn't a function.
"""
type_defn: TypeDef
match ty:
case TypeDef() as type_defn:
pass
case BoundTypeVar() | ExistentialTypeVar():
return None
case NumericType(kind):
match kind:
case NumericType.Kind.Nat:
type_defn = nat_type_def
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
type_defn = float_type_def
case kind:
return assert_never(kind)
case FunctionType():
type_defn = callable_type_def
case OpaqueType() as ty:
type_defn = ty.defn
case StructType() as ty:
type_defn = ty.defn
case TupleType():
type_defn = tuple_type_def
case NoneType():
type_defn = none_type_def
case _:
return assert_never(ty)

type_defn = cast(TypeDef, self.get_checked(type_defn.id))
if type_defn.id in DEF_STORE.impls and name in DEF_STORE.impls[type_defn.id]:
def_id = DEF_STORE.impls[type_defn.id][name]
defn = self.get_parsed(def_id)
if isinstance(defn, CallableDef):
return defn
return None


ENGINE: CompilationEngine = CompilationEngine()
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
CustomCallChecker,
)
from guppylang_internals.diagnostic import Error, Note
from guppylang_internals.engine import ENGINE
from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang_internals.nodes import (
BarrierExpr,
Expand Down Expand Up @@ -79,7 +80,7 @@ def parse_name(self) -> str:
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
[self_arg, other_arg] = args
self_arg, self_ty = ExprSynthesizer(self.ctx).synthesize(self_arg)
f = self.ctx.globals.get_instance_func(self_ty, self.parse_name())
f = ENGINE.get_instance_func(self_ty, self.parse_name())
assert f is not None
return f.synthesize_call([other_arg, self_arg], self.node, self.ctx)

Expand Down Expand Up @@ -135,7 +136,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
arg, ty = ExprSynthesizer(self.ctx).synthesize(arg)
is_callable = (
isinstance(ty, FunctionType)
or self.ctx.globals.get_instance_func(ty, "__call__") is not None
or ENGINE.get_instance_func(ty, "__call__") is not None
)
const = with_loc(self.node, ast.Constant(value=is_callable))
return const, bool_type()
Expand Down Expand Up @@ -402,7 +403,7 @@ def to_sized_iter(
) -> tuple[ast.expr, Type]:
"""Adds a static size annotation to an iterator."""
sized_iter_ty = sized_iter_type(range_ty, size)
make_sized_iter = ctx.globals.get_instance_func(sized_iter_ty, "__new__")
make_sized_iter = ENGINE.get_instance_func(sized_iter_ty, "__new__")
assert make_sized_iter is not None
sized_iter, _ = make_sized_iter.check_call([iterator], sized_iter_ty, iterator, ctx)
return sized_iter, sized_iter_ty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(self, ty: Type, wire: Wire, used: ObjectUse | None = None) -> None:
def __getattr__(self, key: str) -> Any: # type: ignore[misc]
# Guppy objects don't have fields (structs are treated separately below), so the
# only attributes we have to worry about are methods.
func = get_tracing_state().globals.get_instance_func(self._ty, key)
func = ENGINE.get_instance_func(self._ty, key)
if func is None:
raise GuppyComptimeError(
f"Expression of type `{self._ty}` has no attribute `{key}`"
Expand Down Expand Up @@ -455,7 +455,7 @@ def __getattr__(self, key: str) -> Any: # type: ignore[misc]
if key in self._field_values:
return self._field_values[key]
# Or a method
func = get_tracing_state().globals.get_instance_func(self._ty, key)
func = ENGINE.get_instance_func(self._ty, key)
if func is None:
err = f"Expression of type `{self._ty}` has no attribute `{key}`"
raise AttributeError(err)
Expand Down
2 changes: 2 additions & 0 deletions tests/error/comptime_expr_errors/python_err.err
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Traceback printed below:

Traceback (most recent call last):
File "<string>", line 1, in <module>
import sys;exec(eval(sys.stdin.readline()))
^^^^^
ZeroDivisionError: division by zero

Guppy compilation failed due to 1 previous error
Loading