Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions guppylang-internals/src/guppylang_internals/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,8 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
else:
node = ast.parse(source).body[0]
return source, node, line_offset


def fake_call(name: str, args: list[ast.expr]) -> ast.Call:
"""Creates a fake call node with the given name and arguments."""
return ast.Call(func=ast.Name(id=name, ctx=ast.Load()), args=args, keywords=[])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the time we call fake_call, we wrap the result with with_loc, thus this comment is not right. However, since we always use with_loc, maybe can be useful to include the with_loc call in fake_call directly

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wrong in the first comment. The issue underlined by copilot is right, however, I'm not sure if exceptions regarding node.func (like check_num_args) are possible on a fake_call function. @acl-cqc any thoughts?

Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
AstNode,
AstVisitor,
breaks_in_loop,
fake_call,
get_type,
get_type_opt,
return_nodes_in_ast,
Expand Down Expand Up @@ -721,7 +722,11 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]:
return func.synthesize_call([node.operand], node, self.ctx)

def _synthesize_binary(
self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr
self,
left_expr: ast.expr,
right_expr: ast.expr,
op: AstOp,
node: ast.BinOp | ast.Compare,
) -> tuple[ast.expr, Type]:
"""Helper method to compile binary operators by calling out to dunder methods.

Expand Down Expand Up @@ -1047,9 +1052,11 @@ def try_coerce_to(
return None
# Ordering on `NumericType.Kind` defines the coercion relation
if act.kind < exp.kind:
f = ENGINE.get_instance_func(act, f"__{exp.kind.name.lower()}__")
name = f"__{exp.kind.name.lower()}__"
f = ENGINE.get_instance_func(act, name)
assert f is not None
node, subst = f.check_call([node], exp, node, ctx)
call = with_loc(node, fake_call(name, [node]))
node, subst = f.check_call([node], exp, call, ctx)
assert len(subst) == 0, "Coercion methods are not generic"
return node
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def check(self, type_args: Inst, globals: Globals) -> "CustomMonoFunctionDef":

@override
def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
self, args: list[ast.expr], ty: Type, node: ast.Call, ctx: Context
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def check(self, type_args: Inst, globals: Globals) -> "CheckedFunctionDecl":

@override
def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
self, args: list[ast.expr], ty: Type, node: ast.Call, ctx: Context
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type."""
# Use default implementation from the expression checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def check(self, type_args: Inst, globals: Globals) -> "CheckedFunctionDef":

@override
def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
self, args: list[ast.expr], ty: Type, node: ast.Call, ctx: Context
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type."""
# Use default implementation from the expression checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:

@override
def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
self, args: list[ast.expr], ty: Type, node: ast.Call, ctx: Context
) -> tuple[ast.expr, Subst]:
available_sigs: list[OverloadVariant] = []
for def_id in self.func_ids:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def compile_outer(

@override
def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
self, args: list[ast.expr], ty: Type, node: ast.Call, ctx: Context
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type."""
# Use default implementation from the expression checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class TracedFunctionDef(RawTracedFunctionDef, CallableDef, CompilableDef):

@override
def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
self, args: list[ast.expr], ty: Type, node: ast.Call, ctx: Context
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type."""
# Use default implementation from the expression checker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CallableDef(ValueDef):

@abstractmethod
def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: "Context"
self, args: list[ast.expr], ty: Type, node: ast.Call, ctx: "Context"
) -> tuple[ast.expr, Subst]:
"""Checks the return type of a function call against a given type."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing_extensions import assert_never, override

from guppylang_internals.ast_util import get_type, with_loc, with_type
from guppylang_internals.ast_util import fake_call, get_type, with_loc, with_type
from guppylang_internals.checker.core import Context, Variable
from guppylang_internals.checker.errors.generic import UnsupportedError
from guppylang_internals.checker.errors.type_errors import (
Expand Down Expand Up @@ -470,7 +470,8 @@ def to_sized_iter(
sized_iter_ty = sized_iter_type(range_ty, size)
make_sized_iter = ENGINE.get_instance_func(sized_iter_ty, "__new__")
assert make_sized_iter is not None
sized_iter, _ = make_sized_iter.check_call([iterator], sized_iter_ty, iterator, ctx)
call = with_loc(iterator, fake_call("__new__", [iterator]))
sized_iter, _ = make_sized_iter.check_call([iterator], sized_iter_ty, call, ctx)
return sized_iter, sized_iter_ty


Expand Down
Loading