@@ -313,7 +313,7 @@ def getattr(self):
313
313
class EvalFunc :
314
314
"""Class for a callable pyscript function."""
315
315
316
- def __init__ (self , func_def , code_list , code_str , global_ctx ):
316
+ def __init__ (self , func_def , code_list , code_str , global_ctx , async_func = False ):
317
317
"""Initialize a function calling context."""
318
318
self .func_def = func_def
319
319
self .name = func_def .name
@@ -338,6 +338,7 @@ def __init__(self, func_def, code_list, code_str, global_ctx):
338
338
self .trigger = []
339
339
self .trigger_service = set ()
340
340
self .has_closure = False
341
+ self .async_func = async_func
341
342
342
343
def get_name (self ):
343
344
"""Return the function name."""
@@ -930,14 +931,18 @@ async def ast_not_implemented(self, arg, *args):
930
931
name = "ast_" + arg .__class__ .__name__ .lower ()
931
932
raise NotImplementedError (f"{ self .name } : not implemented ast " + name )
932
933
933
- async def aeval (self , arg , undefined_check = True ):
934
+ async def aeval (self , arg , undefined_check = True , do_await = True ):
934
935
"""Vector to specific function based on ast class type."""
935
936
name = "ast_" + arg .__class__ .__name__ .lower ()
936
937
try :
937
938
if hasattr (arg , "lineno" ):
938
939
self .lineno = arg .lineno
939
940
self .col_offset = arg .col_offset
940
- val = await getattr (self , name , self .ast_not_implemented )(arg )
941
+ val = (
942
+ await getattr (self , name , self .ast_not_implemented )(arg )
943
+ if do_await
944
+ else getattr (self , name , self .ast_not_implemented )(arg )
945
+ )
941
946
if undefined_check and isinstance (val , EvalName ):
942
947
raise NameError (f"name '{ val .name } ' is not defined" )
943
948
return val
@@ -1102,7 +1107,7 @@ async def ast_classdef(self, arg):
1102
1107
del sym_table ["__init__" ]
1103
1108
sym_table_assign [arg .name ].set (type (arg .name , tuple (bases ), sym_table ))
1104
1109
1105
- async def ast_functiondef (self , arg ):
1110
+ async def ast_functiondef (self , arg , async_func = False ):
1106
1111
"""Evaluate function definition."""
1107
1112
other_dec = []
1108
1113
dec_name = None
@@ -1158,7 +1163,7 @@ async def executor_wrap(*args, **kwargs):
1158
1163
self .sym_table [arg .name ].set (func )
1159
1164
return
1160
1165
1161
- func = EvalFunc (arg , self .code_list , self .code_str , self .global_ctx )
1166
+ func = EvalFunc (arg , self .code_list , self .code_str , self .global_ctx , async_func )
1162
1167
await func .eval_defaults (self )
1163
1168
await func .resolve_nonlocals (self )
1164
1169
name = func .get_name ()
@@ -1215,7 +1220,7 @@ async def ast_lambda(self, arg):
1215
1220
1216
1221
async def ast_asyncfunctiondef (self , arg ):
1217
1222
"""Evaluate async function definition."""
1218
- return await self .ast_functiondef (arg )
1223
+ return await self .ast_functiondef (arg , async_func = True )
1219
1224
1220
1225
async def ast_try (self , arg ):
1221
1226
"""Execute try...except statement."""
@@ -2020,7 +2025,10 @@ async def ast_formattedvalue(self, arg):
2020
2025
2021
2026
async def ast_await (self , arg ):
2022
2027
"""Evaluate await expr."""
2023
- return await self .aeval (arg .value )
2028
+ coro = await self .aeval (arg .value , do_await = False )
2029
+ if coro and asyncio .iscoroutine (coro ):
2030
+ return await coro
2031
+ return coro
2024
2032
2025
2033
async def get_target_names (self , lhs ):
2026
2034
"""Recursively find all the target names mentioned in the AST tree."""
0 commit comments