diff --git a/guppylang-internals/src/guppylang_internals/checker/core.py b/guppylang-internals/src/guppylang_internals/checker/core.py index 4056b0daa..dc7250eb4 100644 --- a/guppylang-internals/src/guppylang_internals/checker/core.py +++ b/guppylang-internals/src/guppylang_internals/checker/core.py @@ -12,7 +12,6 @@ NamedTuple, TypeAlias, TypeVar, - cast, overload, ) @@ -25,29 +24,12 @@ Definition, ParsedDef, ) -from guppylang_internals.definition.ty import TypeDef -from guppylang_internals.definition.value import CallableDef from guppylang_internals.engine import BUILTIN_DEFS, DEF_STORE, ENGINE from guppylang_internals.error import InternalGuppyError -from guppylang_internals.tys.builtin import ( - callable_type_def, - float_type_def, - int_type_def, - nat_type_def, - none_type_def, - tuple_type_def, -) from guppylang_internals.tys.param import Parameter from guppylang_internals.tys.ty import ( - BoundTypeVar, - ExistentialTypeVar, - FunctionType, InputFlags, - NoneType, - NumericType, - OpaqueType, StructType, - TupleType, Type, ) @@ -350,48 +332,6 @@ def builtin_defs() -> dict[str, Definition]: if isinstance(val, GuppyDefinition) } - def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None: - """Looks up an instance function with a given name for a type. - - Returns `None` if the name doesn't exist or isn't a function. - """ - type_defn: TypeDef - match ty: - case TypeDef() as type_defn: - pass - case BoundTypeVar() | ExistentialTypeVar(): - return None - case NumericType(kind): - match kind: - case NumericType.Kind.Nat: - type_defn = nat_type_def - case NumericType.Kind.Int: - type_defn = int_type_def - case NumericType.Kind.Float: - type_defn = float_type_def - case kind: - return assert_never(kind) - case FunctionType(): - type_defn = callable_type_def - case OpaqueType() as ty: - type_defn = ty.defn - case StructType() as ty: - type_defn = ty.defn - case TupleType(): - type_defn = tuple_type_def - case NoneType(): - type_defn = none_type_def - case _: - return assert_never(ty) - - type_defn = cast(TypeDef, ENGINE.get_checked(type_defn.id)) - if type_defn.id in DEF_STORE.impls and name in DEF_STORE.impls[type_defn.id]: - def_id = DEF_STORE.impls[type_defn.id][name] - defn = ENGINE.get_parsed(def_id) - if isinstance(defn, CallableDef): - return defn - return None - def __contains__(self, item: DefId | str) -> bool: match item: case DefId() as def_id: diff --git a/guppylang-internals/src/guppylang_internals/checker/expr_checker.py b/guppylang-internals/src/guppylang_internals/checker/expr_checker.py index bdf11fdef..8cfbd7982 100644 --- a/guppylang-internals/src/guppylang_internals/checker/expr_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/expr_checker.py @@ -88,6 +88,7 @@ from guppylang_internals.definition.parameter import ParamDef from guppylang_internals.definition.ty import TypeDef from guppylang_internals.definition.value import CallableDef, ValueDef +from guppylang_internals.engine import ENGINE from guppylang_internals.error import ( GuppyError, GuppyTypeError, @@ -343,7 +344,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: TensorCall(func=node.func, args=processed_args, tensor_ty=tensor_ty), ), subst - elif callee := self.ctx.globals.get_instance_func(func_ty, "__call__"): + elif callee := ENGINE.get_instance_func(func_ty, "__call__"): return callee.check_call(node.args, ty, node, self.ctx) else: raise GuppyTypeError(NotCallableError(node.func, func_ty)) @@ -455,7 +456,7 @@ def _check_global( case ValueDef() as defn: return with_loc(node, GlobalName(id=name, def_id=defn.id)), defn.ty # For types, we return their `__new__` constructor - case TypeDef() as defn if constr := self.ctx.globals.get_instance_func( + case TypeDef() as defn if constr := ENGINE.get_instance_func( defn, "__new__" ): return with_loc(node, GlobalName(id=name, def_id=constr.id)), constr.ty @@ -513,7 +514,7 @@ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]: # you loose access to all fields besides `a`). expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field) return with_loc(node, expr), field.ty - elif func := self.ctx.globals.get_instance_func(ty, node.attr): + elif func := ENGINE.get_instance_func(ty, node.attr): name = with_type( func.ty, with_loc(node, GlobalName(id=func.name, def_id=func.id)) ) @@ -581,7 +582,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]: # Check all other unary expressions by calling out to instance dunder methods op, display_name = unary_table[node.op.__class__] - func = self.ctx.globals.get_instance_func(op_ty, op) + func = ENGINE.get_instance_func(op_ty, op) if func is None: raise GuppyTypeError( UnaryOperatorNotDefinedError(node.operand, op_ty, display_name) @@ -602,11 +603,11 @@ def _synthesize_binary( left_expr, left_ty = self.synthesize(left_expr) right_expr, right_ty = self.synthesize(right_expr) - if func := self.ctx.globals.get_instance_func(left_ty, lop): + if func := ENGINE.get_instance_func(left_ty, lop): with suppress(GuppyError): return func.synthesize_call([left_expr, right_expr], node, self.ctx) - if func := self.ctx.globals.get_instance_func(right_ty, rop): + if func := ENGINE.get_instance_func(right_ty, rop): with suppress(GuppyError): return func.synthesize_call([right_expr, left_expr], node, self.ctx) @@ -634,7 +635,7 @@ def synthesize_instance_func( given expected signature. """ node, ty = self.synthesize(node) - func = self.ctx.globals.get_instance_func(ty, func_name) + func = ENGINE.get_instance_func(ty, func_name) if func is None: err = BadProtocolError(node, ty, description) if give_reason and exp_sig is not None: @@ -770,7 +771,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: node, TensorCall(func=node.func, args=args, tensor_ty=tensor_ty) ), return_ty - elif f := self.ctx.globals.get_instance_func(ty, "__call__"): + elif f := ENGINE.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) else: raise GuppyTypeError(NotCallableError(node.func, ty)) @@ -915,7 +916,7 @@ def try_coerce_to( return None # Ordering on `NumericType.Kind` defines the coercion relation if act.kind < exp.kind: - f = ctx.globals.get_instance_func(act, f"__{exp.kind.name.lower()}__") + f = ENGINE.get_instance_func(act, f"__{exp.kind.name.lower()}__") assert f is not None node, subst = f.check_call([node], exp, node, ctx) assert len(subst) == 0, "Coercion methods are not generic" diff --git a/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py b/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py index a46ba7877..d5347f834 100644 --- a/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py +++ b/guppylang-internals/src/guppylang_internals/checker/stmt_checker.py @@ -53,6 +53,7 @@ check_place_assignable, synthesize_comprehension, ) +from guppylang_internals.engine import ENGINE from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError from guppylang_internals.nodes import ( AnyUnpack, @@ -326,7 +327,7 @@ def _check_unpackable( ) raise GuppyError(err) - elif self.ctx.globals.get_instance_func(ty, "__iter__"): + elif ENGINE.get_instance_func(ty, "__iter__"): size = check_iter_unpack_has_static_size(expr, self.ctx) # Create a dummy variable and assign the expression to it. This helps us to # wire it up correctly during Hugr generation. diff --git a/guppylang-internals/src/guppylang_internals/compiler/core.py b/guppylang-internals/src/guppylang_internals/compiler/core.py index 2fb76a68f..32e27dc7e 100644 --- a/guppylang-internals/src/guppylang_internals/compiler/core.py +++ b/guppylang-internals/src/guppylang_internals/compiler/core.py @@ -266,7 +266,7 @@ def build_compiled_instance_func( """ from guppylang_internals.engine import ENGINE - parsed_func = self.checked_globals.get_instance_func(ty, name) + parsed_func = ENGINE.get_instance_func(ty, name) if parsed_func is None: return None checked_func = ENGINE.get_checked(parsed_func.id) diff --git a/guppylang-internals/src/guppylang_internals/engine.py b/guppylang-internals/src/guppylang_internals/engine.py index 214001742..5f74aabd6 100644 --- a/guppylang-internals/src/guppylang_internals/engine.py +++ b/guppylang-internals/src/guppylang_internals/engine.py @@ -1,7 +1,7 @@ from collections import defaultdict from enum import Enum from types import FrameType -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, assert_never, cast import hugr.build.function as hf import hugr.std.collections.array @@ -25,6 +25,7 @@ ) from guppylang_internals.definition.ty import TypeDef from guppylang_internals.definition.value import ( + CallableDef, CompiledCallableDef, CompiledHugrNodeDef, ) @@ -46,7 +47,17 @@ string_type_def, tuple_type_def, ) -from guppylang_internals.tys.ty import FunctionType +from guppylang_internals.tys.ty import ( + BoundTypeVar, + ExistentialTypeVar, + FunctionType, + NoneType, + NumericType, + OpaqueType, + StructType, + TupleType, + Type, +) if TYPE_CHECKING: from guppylang_internals.compiler.core import MonoDefId @@ -308,5 +319,48 @@ def compile(self, id: DefId) -> ModulePointer: } return ModulePointer(Package(modules=[graph.hugr], extensions=extensions), 0) + @pretty_errors + def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None: + """Looks up an instance function with a given name for a type. + + Returns `None` if the name doesn't exist or isn't a function. + """ + type_defn: TypeDef + match ty: + case TypeDef() as type_defn: + pass + case BoundTypeVar() | ExistentialTypeVar(): + return None + case NumericType(kind): + match kind: + case NumericType.Kind.Nat: + type_defn = nat_type_def + case NumericType.Kind.Int: + type_defn = int_type_def + case NumericType.Kind.Float: + type_defn = float_type_def + case kind: + return assert_never(kind) + case FunctionType(): + type_defn = callable_type_def + case OpaqueType() as ty: + type_defn = ty.defn + case StructType() as ty: + type_defn = ty.defn + case TupleType(): + type_defn = tuple_type_def + case NoneType(): + type_defn = none_type_def + case _: + return assert_never(ty) + + type_defn = cast(TypeDef, self.get_checked(type_defn.id)) + if type_defn.id in DEF_STORE.impls and name in DEF_STORE.impls[type_defn.id]: + def_id = DEF_STORE.impls[type_defn.id][name] + defn = self.get_parsed(def_id) + if isinstance(defn, CallableDef): + return defn + return None + ENGINE: CompilationEngine = CompilationEngine() diff --git a/guppylang-internals/src/guppylang_internals/std/_internal/checker.py b/guppylang-internals/src/guppylang_internals/std/_internal/checker.py index 89ca2149a..282e4f328 100644 --- a/guppylang-internals/src/guppylang_internals/std/_internal/checker.py +++ b/guppylang-internals/src/guppylang_internals/std/_internal/checker.py @@ -24,6 +24,7 @@ CustomCallChecker, ) from guppylang_internals.diagnostic import Error, Note +from guppylang_internals.engine import ENGINE from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError from guppylang_internals.nodes import ( BarrierExpr, @@ -79,7 +80,7 @@ def parse_name(self) -> str: def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: [self_arg, other_arg] = args self_arg, self_ty = ExprSynthesizer(self.ctx).synthesize(self_arg) - f = self.ctx.globals.get_instance_func(self_ty, self.parse_name()) + f = ENGINE.get_instance_func(self_ty, self.parse_name()) assert f is not None return f.synthesize_call([other_arg, self_arg], self.node, self.ctx) @@ -135,7 +136,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) is_callable = ( isinstance(ty, FunctionType) - or self.ctx.globals.get_instance_func(ty, "__call__") is not None + or ENGINE.get_instance_func(ty, "__call__") is not None ) const = with_loc(self.node, ast.Constant(value=is_callable)) return const, bool_type() @@ -402,7 +403,7 @@ def to_sized_iter( ) -> tuple[ast.expr, Type]: """Adds a static size annotation to an iterator.""" sized_iter_ty = sized_iter_type(range_ty, size) - make_sized_iter = ctx.globals.get_instance_func(sized_iter_ty, "__new__") + make_sized_iter = ENGINE.get_instance_func(sized_iter_ty, "__new__") assert make_sized_iter is not None sized_iter, _ = make_sized_iter.check_call([iterator], sized_iter_ty, iterator, ctx) return sized_iter, sized_iter_ty diff --git a/guppylang-internals/src/guppylang_internals/tracing/object.py b/guppylang-internals/src/guppylang_internals/tracing/object.py index e4c9c4971..ef224e862 100644 --- a/guppylang-internals/src/guppylang_internals/tracing/object.py +++ b/guppylang-internals/src/guppylang_internals/tracing/object.py @@ -346,7 +346,7 @@ def __init__(self, ty: Type, wire: Wire, used: ObjectUse | None = None) -> None: def __getattr__(self, key: str) -> Any: # type: ignore[misc] # Guppy objects don't have fields (structs are treated separately below), so the # only attributes we have to worry about are methods. - func = get_tracing_state().globals.get_instance_func(self._ty, key) + func = ENGINE.get_instance_func(self._ty, key) if func is None: raise GuppyComptimeError( f"Expression of type `{self._ty}` has no attribute `{key}`" @@ -455,7 +455,7 @@ def __getattr__(self, key: str) -> Any: # type: ignore[misc] if key in self._field_values: return self._field_values[key] # Or a method - func = get_tracing_state().globals.get_instance_func(self._ty, key) + func = ENGINE.get_instance_func(self._ty, key) if func is None: err = f"Expression of type `{self._ty}` has no attribute `{key}`" raise AttributeError(err) diff --git a/tests/error/comptime_expr_errors/python_err.err b/tests/error/comptime_expr_errors/python_err.err index b0303645f..12b54b912 100644 --- a/tests/error/comptime_expr_errors/python_err.err +++ b/tests/error/comptime_expr_errors/python_err.err @@ -9,6 +9,8 @@ Traceback printed below: Traceback (most recent call last): File "", line 1, in + import sys;exec(eval(sys.stdin.readline())) + ^^^^^ ZeroDivisionError: division by zero Guppy compilation failed due to 1 previous error