Skip to content

Commit 8ee7926

Browse files
committed
improved fix from #688, but still pyscript doesn't
correctly handle async function declarations and creating coros when you call them without await.
1 parent a5f3cde commit 8ee7926

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

custom_components/pyscript/eval.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def getattr(self):
313313
class EvalFunc:
314314
"""Class for a callable pyscript function."""
315315

316-
def __init__(self, func_def, code_list, code_str, global_ctx):
316+
def __init__(self, func_def, code_list, code_str, global_ctx, async_func=False):
317317
"""Initialize a function calling context."""
318318
self.func_def = func_def
319319
self.name = func_def.name
@@ -338,6 +338,7 @@ def __init__(self, func_def, code_list, code_str, global_ctx):
338338
self.trigger = []
339339
self.trigger_service = set()
340340
self.has_closure = False
341+
self.async_func = async_func
341342

342343
def get_name(self):
343344
"""Return the function name."""
@@ -930,14 +931,18 @@ async def ast_not_implemented(self, arg, *args):
930931
name = "ast_" + arg.__class__.__name__.lower()
931932
raise NotImplementedError(f"{self.name}: not implemented ast " + name)
932933

933-
async def aeval(self, arg, undefined_check=True):
934+
async def aeval(self, arg, undefined_check=True, do_await=True):
934935
"""Vector to specific function based on ast class type."""
935936
name = "ast_" + arg.__class__.__name__.lower()
936937
try:
937938
if hasattr(arg, "lineno"):
938939
self.lineno = arg.lineno
939940
self.col_offset = arg.col_offset
940-
val = await getattr(self, name, self.ast_not_implemented)(arg)
941+
val = (
942+
await getattr(self, name, self.ast_not_implemented)(arg)
943+
if do_await
944+
else getattr(self, name, self.ast_not_implemented)(arg)
945+
)
941946
if undefined_check and isinstance(val, EvalName):
942947
raise NameError(f"name '{val.name}' is not defined")
943948
return val
@@ -1102,7 +1107,7 @@ async def ast_classdef(self, arg):
11021107
del sym_table["__init__"]
11031108
sym_table_assign[arg.name].set(type(arg.name, tuple(bases), sym_table))
11041109

1105-
async def ast_functiondef(self, arg):
1110+
async def ast_functiondef(self, arg, async_func=False):
11061111
"""Evaluate function definition."""
11071112
other_dec = []
11081113
dec_name = None
@@ -1158,7 +1163,7 @@ async def executor_wrap(*args, **kwargs):
11581163
self.sym_table[arg.name].set(func)
11591164
return
11601165

1161-
func = EvalFunc(arg, self.code_list, self.code_str, self.global_ctx)
1166+
func = EvalFunc(arg, self.code_list, self.code_str, self.global_ctx, async_func)
11621167
await func.eval_defaults(self)
11631168
await func.resolve_nonlocals(self)
11641169
name = func.get_name()
@@ -1215,7 +1220,7 @@ async def ast_lambda(self, arg):
12151220

12161221
async def ast_asyncfunctiondef(self, arg):
12171222
"""Evaluate async function definition."""
1218-
return await self.ast_functiondef(arg)
1223+
return await self.ast_functiondef(arg, async_func=True)
12191224

12201225
async def ast_try(self, arg):
12211226
"""Execute try...except statement."""
@@ -2020,7 +2025,10 @@ async def ast_formattedvalue(self, arg):
20202025

20212026
async def ast_await(self, arg):
20222027
"""Evaluate await expr."""
2023-
return await self.aeval(arg.value)
2028+
coro = await self.aeval(arg.value, do_await=False)
2029+
if coro and asyncio.iscoroutine(coro):
2030+
return await coro
2031+
return coro
20242032

20252033
async def get_target_names(self, lhs):
20262034
"""Recursively find all the target names mentioned in the AST tree."""

tests/test_unit_eval.py

+33
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,39 @@ async def func():
14151415
""",
14161416
42,
14171417
],
1418+
[
1419+
"""
1420+
import asyncio
1421+
async def coro():
1422+
await asyncio.sleep(0.1)
1423+
return "done"
1424+
1425+
await coro()
1426+
""",
1427+
"done",
1428+
],
1429+
[
1430+
"""
1431+
import asyncio
1432+
1433+
@pyscript_compile
1434+
async def nested():
1435+
await asyncio.sleep(1e-8)
1436+
return 42
1437+
1438+
@pyscript_compile
1439+
async def run():
1440+
task = asyncio.create_task(nested())
1441+
1442+
# "task" can now be used to cancel "nested()", or
1443+
# can simply be awaited to wait until it is complete:
1444+
await task
1445+
return "done"
1446+
1447+
await run()
1448+
""",
1449+
"done",
1450+
],
14181451
]
14191452

14201453

0 commit comments

Comments
 (0)