Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions guppylang-internals/src/guppylang_internals/tracing/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from guppylang_internals.checker.errors.type_errors import TypeMismatchError
from guppylang_internals.compiler.core import CompilerContext, DFContainer
from guppylang_internals.compiler.expr_compiler import ExprCompiler
from guppylang_internals.definition.overloaded import OverloadedFunctionDef
from guppylang_internals.definition.value import CallableDef
from guppylang_internals.diagnostic import Error
from guppylang_internals.error import GuppyComptimeError, GuppyError, exception_hook
from guppylang_internals.nodes import PlaceNode
from guppylang_internals.nodes import GlobalCall, PlaceNode
from guppylang_internals.tracing.builtins_mock import mock_builtins
from guppylang_internals.tracing.object import GuppyObject
from guppylang_internals.tracing.state import (
Expand Down Expand Up @@ -173,8 +174,15 @@ def trace_call(func: CallableDef, *args: Any) -> Any:
# Update inouts
# If the input types of the function aren't known, we can't check this.
# This is the case for functions with a custom checker and no type annotations.
if len(func.ty.inputs) != 0:
for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=True):
# For overloaded functions, func.ty is a dummy type with no inputs, so we
# resolve the actual variant from the call node to get the correct input flags.
func_inputs = func.ty.inputs
if isinstance(func, OverloadedFunctionDef) and isinstance(call_node, GlobalCall):
resolved_def = state.globals[call_node.def_id]
if isinstance(resolved_def, CallableDef):
func_inputs = resolved_def.ty.inputs
if len(func_inputs) != 0:
for inp, arg, var in zip(func_inputs, args, arg_vars, strict=True):
if InputFlags.Inout in inp.flags:
# Note that `inp.ty` could refer to bound variables in the function
# signature. Instead, make sure to use `var.ty` which will always be a
Expand Down
Loading
Loading