Skip to content

Commit 5e8d893

Browse files
authored
Merge pull request #37 from modern-python/28-feature-websockets-in-litestar
add ws support for litestar
2 parents d7f333d + 3901f79 commit 5e8d893

File tree

4 files changed

+74
-16
lines changed

4 files changed

+74
-16
lines changed

packages/modern-di-litestar/modern_di_litestar/main.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def setup_di(app: litestar.Litestar, scope: enum.IntEnum = DIScope.APP) -> Conta
1818

1919

2020
def prepare_di_dependencies() -> dict[str, Provide]:
21-
return {"request_di_container": Provide(build_di_container)}
21+
return {"di_container": Provide(build_di_container)}
2222

2323

2424
def fetch_di_container(app: litestar.Litestar) -> Container:
@@ -28,8 +28,14 @@ def fetch_di_container(app: litestar.Litestar) -> Container:
2828
async def build_di_container(
2929
request: litestar.Request[typing.Any, typing.Any, typing.Any],
3030
) -> typing.AsyncIterator[Container]:
31-
scope = DIScope.REQUEST
32-
context = {"request": request}
31+
context: dict[str, typing.Any] = {}
32+
scope: DIScope | None
33+
if isinstance(request, litestar.WebSocket):
34+
context["websocket"] = request
35+
scope = DIScope.SESSION
36+
else:
37+
context["request"] = request
38+
scope = DIScope.REQUEST
3339
container: Container = fetch_di_container(request.app)
3440
async with container.build_child_container(context=context, scope=scope) as request_container:
3541
yield request_container
@@ -39,11 +45,9 @@ async def build_di_container(
3945
class _Dependency(typing.Generic[T_co]):
4046
dependency: providers.AbstractProvider[T_co]
4147

42-
async def __call__(
43-
self, request_di_container: typing.Annotated[Container | None, Dependency()] = None
44-
) -> T_co | None:
45-
assert request_di_container
46-
return await self.dependency.async_resolve(request_di_container)
48+
async def __call__(self, di_container: typing.Annotated[Container | None, Dependency()] = None) -> T_co | None:
49+
assert di_container
50+
return await self.dependency.async_resolve(di_container)
4751

4852

4953
def FromDI(dependency: providers.AbstractProvider[T_co]) -> Provide: # noqa: N802

packages/modern-di-litestar/tests_litestar/test_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ async def read_root(method: str) -> None:
5555

5656
def test_factories_action_scope(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None:
5757
@litestar.get("/")
58-
async def read_root(request_di_container: modern_di.Container) -> None:
59-
with request_di_container.build_child_container() as action_container:
58+
async def read_root(di_container: modern_di.Container) -> None:
59+
with di_container.build_child_container() as action_container:
6060
action_factory_instance = action_factory.sync_resolve(action_container)
6161
assert isinstance(action_factory_instance, DependentCreator)
6262

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,66 @@
1+
import typing
2+
13
import litestar
24
from litestar.testing import TestClient
5+
from modern_di import Container, Scope, providers
6+
from modern_di_litestar import FromDI
7+
8+
from tests_litestar.dependencies import DependentCreator, SimpleCreator
9+
10+
11+
def context_adapter_function(*, websocket: litestar.Request[typing.Any, typing.Any, typing.Any], **_: object) -> str:
12+
assert isinstance(websocket, litestar.WebSocket)
13+
return websocket.url.path
14+
15+
16+
app_factory = providers.Factory(Scope.APP, SimpleCreator, dep1="original")
17+
session_factory = providers.Factory(Scope.SESSION, DependentCreator, dep1=app_factory.cast)
18+
request_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=app_factory.cast)
19+
context_adapter = providers.ContextAdapter(Scope.SESSION, context_adapter_function)
20+
21+
22+
async def test_factories(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None:
23+
@litestar.websocket_listener(
24+
"/ws",
25+
dependencies={"app_factory_instance": FromDI(app_factory), "session_factory_instance": FromDI(session_factory)},
26+
)
27+
async def websocket_handler(
28+
data: str,
29+
app_factory_instance: SimpleCreator,
30+
session_factory_instance: DependentCreator,
31+
) -> None:
32+
assert data == "test"
33+
assert isinstance(app_factory_instance, SimpleCreator)
34+
assert isinstance(session_factory_instance, DependentCreator)
35+
assert session_factory_instance.dep1 is not app_factory_instance
36+
37+
app.register(websocket_handler)
38+
39+
with client.websocket_connect("/ws") as websocket:
40+
websocket.send("test")
41+
42+
43+
async def test_factories_request_scope(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None:
44+
@litestar.websocket_listener("/ws")
45+
async def websocket_handler(data: str, di_container: Container) -> None:
46+
assert data == "test"
47+
with di_container.build_child_container() as request_container:
48+
request_factory_instance = request_factory.sync_resolve(request_container)
49+
assert isinstance(request_factory_instance, DependentCreator)
50+
51+
app.register(websocket_handler)
52+
53+
with client.websocket_connect("/ws") as websocket:
54+
websocket.send("test")
355

456

5-
async def test_websocket_not_supported(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None:
6-
async def websocket_handler(data: str) -> None:
7-
pass
57+
async def test_context_adapter(client: TestClient[litestar.Litestar], app: litestar.Litestar) -> None:
58+
@litestar.websocket_listener("/ws", dependencies={"path": FromDI(context_adapter)})
59+
async def websocket_handler(data: str, path: str) -> None:
60+
assert data == "test"
61+
assert path == "/ws"
862

9-
app.register(litestar.websocket_listener("/ws")(websocket_handler))
63+
app.register(websocket_handler)
1064

1165
with client.websocket_connect("/ws") as websocket:
1266
websocket.send("test")

packages/modern-di/modern_di/providers/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def __init__(
5858
super().__init__(scope)
5959
self._check_providers_scope(itertools.chain(args, kwargs.values()))
6060
self._creator: typing.Final = creator
61-
self._args: typing.Final = args
62-
self._kwargs: typing.Final = kwargs
61+
self._args: typing.Final[P.args] = args
62+
self._kwargs: typing.Final[P.kwargs] = kwargs
6363

6464
def _sync_resolve_args(self, container: Container) -> list[typing.Any]:
6565
return [x.sync_resolve(container) if isinstance(x, AbstractProvider) else x for x in self._args]

0 commit comments

Comments
 (0)