|
| 1 | +import typing |
| 2 | + |
1 | 3 | import litestar
|
2 | 4 | from litestar.testing import TestClient
|
| 5 | +from modern_di import Container, Scope, providers |
| 6 | +from modern_di_litestar import FromDI |
| 7 | + |
| 8 | +from tests_litestar.dependencies import DependentCreator, SimpleCreator |
| 9 | + |
| 10 | + |
| 11 | +def context_adapter_function(*, websocket: litestar.Request[typing.Any, typing.Any, typing.Any], **_: object) -> str: |
| 12 | + assert isinstance(websocket, litestar.WebSocket) |
| 13 | + return websocket.url.path |
| 14 | + |
| 15 | + |
| 16 | +app_factory = providers.Factory(Scope.APP, SimpleCreator, dep1="original") |
| 17 | +session_factory = providers.Factory(Scope.SESSION, DependentCreator, dep1=app_factory.cast) |
| 18 | +request_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=app_factory.cast) |
| 19 | +context_adapter = providers.ContextAdapter(Scope.SESSION, context_adapter_function) |
| 20 | + |
| 21 | + |
| 22 | +async def test_factories(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None: |
| 23 | + @litestar.websocket_listener( |
| 24 | + "/ws", |
| 25 | + dependencies={"app_factory_instance": FromDI(app_factory), "session_factory_instance": FromDI(session_factory)}, |
| 26 | + ) |
| 27 | + async def websocket_handler( |
| 28 | + data: str, |
| 29 | + app_factory_instance: SimpleCreator, |
| 30 | + session_factory_instance: DependentCreator, |
| 31 | + ) -> None: |
| 32 | + assert data == "test" |
| 33 | + assert isinstance(app_factory_instance, SimpleCreator) |
| 34 | + assert isinstance(session_factory_instance, DependentCreator) |
| 35 | + assert session_factory_instance.dep1 is not app_factory_instance |
| 36 | + |
| 37 | + app.register(websocket_handler) |
| 38 | + |
| 39 | + with client.websocket_connect("/ws") as websocket: |
| 40 | + websocket.send("test") |
| 41 | + |
| 42 | + |
| 43 | +async def test_factories_request_scope(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None: |
| 44 | + @litestar.websocket_listener("/ws") |
| 45 | + async def websocket_handler(data: str, di_container: Container) -> None: |
| 46 | + assert data == "test" |
| 47 | + with di_container.build_child_container() as request_container: |
| 48 | + request_factory_instance = request_factory.sync_resolve(request_container) |
| 49 | + assert isinstance(request_factory_instance, DependentCreator) |
| 50 | + |
| 51 | + app.register(websocket_handler) |
| 52 | + |
| 53 | + with client.websocket_connect("/ws") as websocket: |
| 54 | + websocket.send("test") |
3 | 55 |
|
4 | 56 |
|
5 |
| -async def test_websocket_not_supported(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None: |
6 |
| - async def websocket_handler(data: str) -> None: |
7 |
| - pass |
| 57 | +async def test_context_adapter(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None: |
| 58 | + @litestar.websocket_listener("/ws", dependencies={"path": FromDI(context_adapter)}) |
| 59 | + async def websocket_handler(data: str, path: str) -> None: |
| 60 | + assert data == "test" |
| 61 | + assert path == "/ws" |
8 | 62 |
|
9 |
| - app.register(litestar.websocket_listener("/ws")(websocket_handler)) |
| 63 | + app.register(websocket_handler) |
10 | 64 |
|
11 | 65 | with client.websocket_connect("/ws") as websocket:
|
12 | 66 | websocket.send("test")
|
0 commit comments