Skip to content

Commit c292b9a

Browse files
committed
Implement async steps.
1 parent 8fd0fd5 commit c292b9a

File tree

5 files changed

+770
-9
lines changed

5 files changed

+770
-9
lines changed

src/pytest_bdd/asyncio.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pytest_bdd.steps import async_given, async_then, async_when
2+
3+
__all__ = ["async_given", "async_when", "async_then"]

src/pytest_bdd/scenario.py

+78-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
"""
1313
from __future__ import annotations
1414

15+
import asyncio
1516
import contextlib
17+
import functools
18+
import inspect
1619
import logging
1720
import os
1821
import re
@@ -34,7 +37,6 @@
3437

3538
from .parser import Feature, Scenario, ScenarioTemplate, Step
3639

37-
3840
logger = logging.getLogger(__name__)
3941

4042

@@ -156,7 +158,14 @@ def _execute_step_function(
156158

157159
request.config.hook.pytest_bdd_before_step_call(**kw)
158160
# Execute the step as if it was a pytest fixture, so that we can allow "yield" statements in it
159-
return_value = call_fixture_func(fixturefunc=context.step_func, request=request, kwargs=kwargs)
161+
step_func = context.step_func
162+
if context.is_async:
163+
if inspect.isasyncgenfunction(context.step_func):
164+
step_func = _wrap_asyncgen(request, context.step_func)
165+
elif inspect.iscoroutinefunction(context.step_func):
166+
step_func = _wrap_coroutine(context.step_func)
167+
168+
return_value = call_fixture_func(fixturefunc=step_func, request=request, kwargs=kwargs)
160169
except Exception as exception:
161170
request.config.hook.pytest_bdd_step_error(exception=exception, **kw)
162171
raise
@@ -167,6 +176,73 @@ def _execute_step_function(
167176
request.config.hook.pytest_bdd_after_step(**kw)
168177

169178

179+
def _wrap_asyncgen(request: FixtureRequest, func: Callable) -> Callable:
180+
"""Wrapper for an async_generator function.
181+
182+
This will wrap the function in a synchronized method to return the first
183+
yielded value from the generator. A finalizer will be added to the fixture
184+
to ensure that no other values are yielded and that the loop is closed.
185+
186+
:param request: The fixture request.
187+
:param func: The function to wrap.
188+
189+
:returns: The wrapped function.
190+
"""
191+
192+
@functools.wraps(func)
193+
def _wrapper(*args, **kwargs):
194+
try:
195+
loop, created = asyncio.get_running_loop(), False
196+
except RuntimeError:
197+
loop, created = asyncio.get_event_loop_policy().new_event_loop(), True
198+
199+
async_obj = func(*args, **kwargs)
200+
201+
def _finalizer() -> None:
202+
"""Ensure no more values are yielded and close the loop."""
203+
try:
204+
loop.run_until_complete(async_obj.__anext__())
205+
except StopAsyncIteration:
206+
pass
207+
else:
208+
raise ValueError("Async generator must only yield once.")
209+
210+
if created:
211+
loop.close()
212+
213+
value = loop.run_until_complete(async_obj.__anext__())
214+
request.addfinalizer(_finalizer)
215+
216+
return value
217+
218+
return _wrapper
219+
220+
221+
def _wrap_coroutine(func: Callable) -> Callable:
222+
"""Wrapper for a coroutine function.
223+
224+
:param func: The function to wrap.
225+
226+
:returns: The wrapped function.
227+
"""
228+
229+
@functools.wraps(func)
230+
def _wrapper(*args, **kwargs):
231+
try:
232+
loop, created = asyncio.get_running_loop(), False
233+
except RuntimeError:
234+
loop, created = asyncio.get_event_loop_policy().new_event_loop(), True
235+
236+
try:
237+
async_obj = func(*args, **kwargs)
238+
return loop.run_until_complete(async_obj)
239+
finally:
240+
if created:
241+
loop.close()
242+
243+
return _wrapper
244+
245+
170246
def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequest) -> None:
171247
"""Execute the scenario.
172248

src/pytest_bdd/steps.py

+77-3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class StepFunctionContext:
6666
parser: StepParser
6767
converters: dict[str, Callable[..., Any]] = field(default_factory=dict)
6868
target_fixture: str | None = None
69+
is_async: bool = False
6970

7071

7172
def get_step_fixture_name(step: Step) -> str:
@@ -78,6 +79,7 @@ def given(
7879
converters: dict[str, Callable] | None = None,
7980
target_fixture: str | None = None,
8081
stacklevel: int = 1,
82+
is_async: bool = False,
8183
) -> Callable:
8284
"""Given step decorator.
8385
@@ -86,17 +88,62 @@ def given(
8688
{<param_name>: <converter function>}.
8789
:param target_fixture: Target fixture name to replace by steps definition function.
8890
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
91+
:param is_async: True if the step is asynchronous. (Default: False)
8992
9093
:return: Decorator function for the step.
9194
"""
92-
return step(name, GIVEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
95+
return step(
96+
name, GIVEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
97+
)
98+
99+
100+
def async_given(
101+
name: str | StepParser,
102+
converters: dict[str, Callable] | None = None,
103+
target_fixture: str | None = None,
104+
stacklevel: int = 1,
105+
) -> Callable:
106+
"""Async Given step decorator.
107+
108+
:param name: Step name or a parser object.
109+
:param converters: Optional `dict` of the argument or parameter converters in form
110+
{<param_name>: <converter function>}.
111+
:param target_fixture: Target fixture name to replace by steps definition function.
112+
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
113+
114+
:return: Decorator function for the step.
115+
"""
116+
return given(name, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)
93117

94118

95119
def when(
96120
name: str | StepParser,
97121
converters: dict[str, Callable] | None = None,
98122
target_fixture: str | None = None,
99123
stacklevel: int = 1,
124+
is_async: bool = False,
125+
) -> Callable:
126+
"""When step decorator.
127+
128+
:param name: Step name or a parser object.
129+
:param converters: Optional `dict` of the argument or parameter converters in form
130+
{<param_name>: <converter function>}.
131+
:param target_fixture: Target fixture name to replace by steps definition function.
132+
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
133+
:param is_async: True if the step is asynchronous. (Default: False)
134+
135+
:return: Decorator function for the step.
136+
"""
137+
return step(
138+
name, WHEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
139+
)
140+
141+
142+
def async_when(
143+
name: str | StepParser,
144+
converters: dict[str, Callable] | None = None,
145+
target_fixture: str | None = None,
146+
stacklevel: int = 1,
100147
) -> Callable:
101148
"""When step decorator.
102149
@@ -108,14 +155,15 @@ def when(
108155
109156
:return: Decorator function for the step.
110157
"""
111-
return step(name, WHEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
158+
return when(name, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)
112159

113160

114161
def then(
115162
name: str | StepParser,
116163
converters: dict[str, Callable] | None = None,
117164
target_fixture: str | None = None,
118165
stacklevel: int = 1,
166+
is_async: bool = False,
119167
) -> Callable:
120168
"""Then step decorator.
121169
@@ -124,10 +172,32 @@ def then(
124172
{<param_name>: <converter function>}.
125173
:param target_fixture: Target fixture name to replace by steps definition function.
126174
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
175+
:param is_async: True if the step is asynchronous. (Default: False)
127176
128177
:return: Decorator function for the step.
129178
"""
130-
return step(name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
179+
return step(
180+
name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
181+
)
182+
183+
184+
def async_then(
185+
name: str | StepParser,
186+
converters: dict[str, Callable] | None = None,
187+
target_fixture: str | None = None,
188+
stacklevel: int = 1,
189+
) -> Callable:
190+
"""Then step decorator.
191+
192+
:param name: Step name or a parser object.
193+
:param converters: Optional `dict` of the argument or parameter converters in form
194+
{<param_name>: <converter function>}.
195+
:param target_fixture: Target fixture name to replace by steps definition function.
196+
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
197+
198+
:return: Decorator function for the step.
199+
"""
200+
return step(name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)
131201

132202

133203
def step(
@@ -136,6 +206,7 @@ def step(
136206
converters: dict[str, Callable] | None = None,
137207
target_fixture: str | None = None,
138208
stacklevel: int = 1,
209+
is_async: bool = False,
139210
) -> Callable[[TCallable], TCallable]:
140211
"""Generic step decorator.
141212
@@ -144,6 +215,7 @@ def step(
144215
:param converters: Optional step arguments converters mapping.
145216
:param target_fixture: Optional fixture name to replace by step definition.
146217
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
218+
:param is_async: True if the step is asynchronous. (Default: False)
147219
148220
:return: Decorator function for the step.
149221
@@ -165,6 +237,7 @@ def decorator(func: TCallable) -> TCallable:
165237
parser=parser,
166238
converters=converters,
167239
target_fixture=target_fixture,
240+
is_async=is_async,
168241
)
169242

170243
def step_function_marker() -> StepFunctionContext:
@@ -177,6 +250,7 @@ def step_function_marker() -> StepFunctionContext:
177250
f"{StepNamePrefix.step_def.value}_{type_ or '*'}_{parser.name}", seen=caller_locals.keys()
178251
)
179252
caller_locals[fixture_step_name] = pytest.fixture(name=fixture_step_name)(step_function_marker)
253+
180254
return func
181255

182256
return decorator

0 commit comments

Comments
 (0)