Skip to content

Commit e391c41

Browse files
jkmntvytas7
andauthored
feat(typing): add middleware protocols (#2390)
* typing: add middleware protocols * ci: fix * chore: fix a bad master merge (but with typing issues now) * fix(typing): restore the `*` to prevent overload overlap error --------- Co-authored-by: Vytautas Liuolia <vytautas.liuolia@gmail.com>
1 parent fe6996a commit e391c41

5 files changed

Lines changed: 196 additions & 27 deletions

File tree

falcon/_typing.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,147 @@ def __call__(self, media: Any, content_type: Optional[str] = ...) -> bytes: ...
198198
DeserializeSync = Callable[[bytes], Any]
199199

200200
Responder = Union[ResponderMethod, AsgiResponderMethod]
201+
202+
203+
# Middleware
204+
class MiddlewareWithProcessRequest(Protocol):
205+
"""WSGI Middleware with request handler."""
206+
207+
def process_request(self, req: Request, resp: Response) -> None: ...
208+
209+
210+
class MiddlewareWithProcessResource(Protocol):
211+
"""WSGI Middleware with resource handler."""
212+
213+
def process_resource(
214+
self,
215+
req: Request,
216+
resp: Response,
217+
resource: object,
218+
params: Dict[str, Any],
219+
) -> None: ...
220+
221+
222+
class MiddlewareWithProcessResponse(Protocol):
223+
"""WSGI Middleware with response handler."""
224+
225+
def process_response(
226+
self, req: Request, resp: Response, resource: object, req_succeeded: bool
227+
) -> None: ...
228+
229+
230+
class AsgiMiddlewareWithProcessStartup(Protocol):
231+
"""ASGI middleware with startup handler."""
232+
233+
async def process_startup(
234+
self, scope: Mapping[str, Any], event: Mapping[str, Any]
235+
) -> None: ...
236+
237+
238+
class AsgiMiddlewareWithProcessShutdown(Protocol):
239+
"""ASGI middleware with shutdown handler."""
240+
241+
async def process_shutdown(
242+
self, scope: Mapping[str, Any], event: Mapping[str, Any]
243+
) -> None: ...
244+
245+
246+
class AsgiMiddlewareWithProcessRequest(Protocol):
247+
"""ASGI middleware with request handler."""
248+
249+
async def process_request(self, req: AsgiRequest, resp: AsgiResponse) -> None: ...
250+
251+
252+
class AsgiMiddlewareWithProcessResource(Protocol):
253+
"""ASGI middleware with resource handler."""
254+
255+
async def process_resource(
256+
self,
257+
req: AsgiRequest,
258+
resp: AsgiResponse,
259+
resource: object,
260+
params: Mapping[str, Any],
261+
) -> None: ...
262+
263+
264+
class AsgiMiddlewareWithProcessResponse(Protocol):
265+
"""ASGI middleware with response handler."""
266+
267+
async def process_response(
268+
self,
269+
req: AsgiRequest,
270+
resp: AsgiResponse,
271+
resource: object,
272+
req_succeeded: bool,
273+
) -> None: ...
274+
275+
276+
class MiddlewareWithAsyncProcessRequestWs(Protocol):
277+
"""ASGI middleware with WebSocket request handler."""
278+
279+
async def process_request_ws(self, req: AsgiRequest, ws: WebSocket) -> None: ...
280+
281+
282+
class MiddlewareWithAsyncProcessResourceWs(Protocol):
283+
"""ASGI middleware with WebSocket resource handler."""
284+
285+
async def process_resource_ws(
286+
self,
287+
req: AsgiRequest,
288+
ws: WebSocket,
289+
resource: object,
290+
params: Mapping[str, Any],
291+
) -> None: ...
292+
293+
294+
class UniversalMiddlewareWithProcessRequest(MiddlewareWithProcessRequest, Protocol):
295+
"""WSGI/ASGI middleware with request handler."""
296+
297+
async def process_request_async(
298+
self, req: AsgiRequest, resp: AsgiResponse
299+
) -> None: ...
300+
301+
302+
class UniversalMiddlewareWithProcessResource(MiddlewareWithProcessResource, Protocol):
303+
"""WSGI/ASGI middleware with resource handler."""
304+
305+
async def process_resource_async(
306+
self,
307+
req: AsgiRequest,
308+
resp: AsgiResponse,
309+
resource: object,
310+
params: Mapping[str, Any],
311+
) -> None: ...
312+
313+
314+
class UniversalMiddlewareWithProcessResponse(MiddlewareWithProcessResponse, Protocol):
315+
"""WSGI/ASGI middleware with response handler."""
316+
317+
async def process_response_async(
318+
self,
319+
req: AsgiRequest,
320+
resp: AsgiResponse,
321+
resource: object,
322+
req_succeeded: bool,
323+
) -> None: ...
324+
325+
326+
# NOTE(jkmnt): This typing is far from perfect due to the Python typing limitations,
327+
# but better than nothing. Middleware conforming to any protocol of the union
328+
# will pass the type check. Other protocols violations are not checked.
329+
Middleware = Union[
330+
MiddlewareWithProcessRequest,
331+
MiddlewareWithProcessResource,
332+
MiddlewareWithProcessResponse,
333+
]
334+
335+
AsgiMiddleware = Union[
336+
AsgiMiddlewareWithProcessRequest,
337+
AsgiMiddlewareWithProcessResource,
338+
AsgiMiddlewareWithProcessResponse,
339+
AsgiMiddlewareWithProcessStartup,
340+
AsgiMiddlewareWithProcessShutdown,
341+
UniversalMiddlewareWithProcessRequest,
342+
UniversalMiddlewareWithProcessResource,
343+
UniversalMiddlewareWithProcessResponse,
344+
]

falcon/app.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from falcon._typing import ErrorHandler
5252
from falcon._typing import ErrorSerializer
5353
from falcon._typing import FindMethod
54+
from falcon._typing import Middleware
5455
from falcon._typing import ProcessResponseMethod
5556
from falcon._typing import ResponderCallable
5657
from falcon._typing import SinkCallable
@@ -286,7 +287,7 @@ def process_response(
286287
_static_routes: List[
287288
Tuple[routing.StaticRoute, routing.StaticRoute, Literal[False]]
288289
]
289-
_unprepared_middleware: List[object]
290+
_unprepared_middleware: List[Middleware]
290291

291292
# Attributes
292293
req_options: RequestOptions
@@ -305,7 +306,7 @@ def __init__(
305306
media_type: str = constants.DEFAULT_MEDIA_TYPE,
306307
request_type: Optional[Type[Request]] = None,
307308
response_type: Optional[Type[Response]] = None,
308-
middleware: Union[object, Iterable[object]] = None,
309+
middleware: Optional[Union[Middleware, Iterable[Middleware]]] = None,
309310
router: Optional[routing.CompiledRouter] = None,
310311
independent_middleware: bool = True,
311312
cors_enable: bool = False,
@@ -327,17 +328,17 @@ def __init__(
327328
# NOTE(kgriffs): Check to see if middleware is an
328329
# iterable, and if so, append the CORSMiddleware
329330
# instance.
330-
middleware = list(middleware) # type: ignore[arg-type]
331-
middleware.append(cm) # type: ignore[arg-type]
331+
middleware = list(cast(Iterable[Middleware], middleware))
332+
middleware.append(cm)
332333
except TypeError:
333334
# NOTE(kgriffs): Assume the middleware kwarg references
334335
# a single middleware component.
335-
middleware = [middleware, cm]
336+
middleware = [cast(Middleware, middleware), cm]
336337

337338
# set middleware
338339
self._unprepared_middleware = []
339340
self._independent_middleware = independent_middleware
340-
self.add_middleware(middleware)
341+
self.add_middleware(middleware or [])
341342

342343
self._router = router or routing.DefaultRouter()
343344
self._router_search = self._router.find
@@ -524,7 +525,9 @@ def router_options(self) -> routing.CompiledRouterOptions:
524525
"""
525526
return self._router.options
526527

527-
def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
528+
def add_middleware(
529+
self, middleware: Union[Middleware, Iterable[Middleware]]
530+
) -> None:
528531
"""Add one or more additional middleware components.
529532
530533
Arguments:
@@ -535,20 +538,20 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
535538
"""
536539

537540
# NOTE(kgriffs): Since this is called by the initializer, there is
538-
# the chance that middleware may be None.
541+
# the chance that middleware may be empty.
539542
if middleware:
540543
try:
541-
middleware = list(middleware) # type: ignore[call-overload]
544+
middleware = list(cast(Iterable[Middleware], middleware))
542545
except TypeError:
543546
# middleware is not iterable; assume it is just one bare component
544-
middleware = [middleware]
547+
middleware = [cast(Middleware, middleware)]
545548

546549
if (
547550
self._cors_enable
548551
and len(
549552
[
550553
mc
551-
for mc in self._unprepared_middleware + middleware # type: ignore[operator]
554+
for mc in self._unprepared_middleware + middleware
552555
if isinstance(mc, CORSMiddleware)
553556
]
554557
)
@@ -559,7 +562,7 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
559562
'cors_enable (which already constructs one instance)'
560563
)
561564

562-
self._unprepared_middleware += middleware # type: ignore[arg-type]
565+
self._unprepared_middleware += middleware
563566

564567
# NOTE(kgriffs): Even if middleware is None or an empty list, we still
565568
# need to make sure self._middleware is initialized if this is the
@@ -1012,7 +1015,7 @@ def my_serializer(
10121015
# ------------------------------------------------------------------------
10131016

10141017
def _prepare_middleware(
1015-
self, middleware: List[object], independent_middleware: bool = False
1018+
self, middleware: List[Middleware], independent_middleware: bool = False
10161019
) -> helpers.PreparedMiddlewareResult:
10171020
return helpers.prepare_middleware(
10181021
middleware=middleware, independent_middleware=independent_middleware

falcon/app_helpers.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from typing import IO, Iterable, List, Literal, Optional, overload, Tuple, Union
2121

2222
from falcon import util
23+
from falcon._typing import AsgiMiddleware
2324
from falcon._typing import AsgiProcessRequestMethod as APRequest
2425
from falcon._typing import AsgiProcessRequestWsMethod
2526
from falcon._typing import AsgiProcessResourceMethod as APResource
2627
from falcon._typing import AsgiProcessResourceWsMethod
2728
from falcon._typing import AsgiProcessResponseMethod as APResponse
29+
from falcon._typing import Middleware
2830
from falcon._typing import ProcessRequestMethod as PRequest
2931
from falcon._typing import ProcessResourceMethod as PResource
3032
from falcon._typing import ProcessResponseMethod as PResponse
@@ -62,24 +64,31 @@
6264

6365
@overload
6466
def prepare_middleware(
65-
middleware: Iterable, independent_middleware: bool = ..., asgi: Literal[False] = ...
67+
middleware: Iterable[Middleware],
68+
independent_middleware: bool = ...,
69+
asgi: Literal[False] = ...,
6670
) -> PreparedMiddlewareResult: ...
6771

6872

6973
@overload
7074
def prepare_middleware(
71-
middleware: Iterable, independent_middleware: bool = ..., asgi: Literal[True] = ...
75+
middleware: Iterable[AsgiMiddleware],
76+
independent_middleware: bool = ...,
77+
*,
78+
asgi: Literal[True],
7279
) -> AsyncPreparedMiddlewareResult: ...
7380

7481

7582
@overload
7683
def prepare_middleware(
77-
middleware: Iterable, independent_middleware: bool = ..., asgi: bool = ...
84+
middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]],
85+
independent_middleware: bool = ...,
86+
asgi: bool = ...,
7887
) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]: ...
7988

8089

8190
def prepare_middleware(
82-
middleware: Iterable[object],
91+
middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]],
8392
independent_middleware: bool = False,
8493
asgi: bool = False,
8594
) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]:
@@ -214,7 +223,7 @@ def prepare_middleware(
214223

215224

216225
def prepare_middleware_ws(
217-
middleware: Iterable[object],
226+
middleware: Iterable[AsgiMiddleware],
218227
) -> AsyncPreparedMiddlewareWsResult:
219228
"""Check middleware interfaces and prepare WebSocket methods for request handling.
220229

falcon/asgi/app.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from falcon import routing
4343
from falcon._typing import _UNSET
4444
from falcon._typing import AsgiErrorHandler
45+
from falcon._typing import AsgiMiddleware
4546
from falcon._typing import AsgiReceive
4647
from falcon._typing import AsgiResponderCallable
4748
from falcon._typing import AsgiResponderWsCallable
@@ -356,6 +357,7 @@ async def process_resource_ws(
356357
_middleware_ws: AsyncPreparedMiddlewareWsResult
357358
_request_type: Type[Request]
358359
_response_type: Type[Response]
360+
_unprepared_middleware: List[AsgiMiddleware] # type: ignore[assignment]
359361

360362
ws_options: WebSocketOptions
361363
"""A set of behavioral options related to WebSocket connections.
@@ -368,7 +370,7 @@ def __init__(
368370
media_type: str = constants.DEFAULT_MEDIA_TYPE,
369371
request_type: Optional[Type[Request]] = None,
370372
response_type: Optional[Type[Response]] = None,
371-
middleware: Union[object, Iterable[object]] = None,
373+
middleware: Optional[Union[AsgiMiddleware, Iterable[AsgiMiddleware]]] = None,
372374
router: Optional[routing.CompiledRouter] = None,
373375
independent_middleware: bool = True,
374376
cors_enable: bool = False,
@@ -378,7 +380,7 @@ def __init__(
378380
media_type,
379381
request_type or Request,
380382
response_type or Response,
381-
middleware,
383+
middleware, # type: ignore[arg-type]
382384
router,
383385
independent_middleware,
384386
cors_enable,
@@ -1163,7 +1165,7 @@ async def _handle_websocket(
11631165
raise
11641166

11651167
def _prepare_middleware( # type: ignore[override]
1166-
self, middleware: List[object], independent_middleware: bool = False
1168+
self, middleware: List[AsgiMiddleware], independent_middleware: bool = False
11671169
) -> AsyncPreparedMiddlewareResult:
11681170
self._middleware_ws = prepare_middleware_ws(middleware)
11691171

falcon/middleware.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from __future__ import annotations
22

3-
from typing import Any, Iterable, Optional, Union
3+
from typing import Iterable, Optional, TYPE_CHECKING, Union
44

5-
from .request import Request
6-
from .response import Response
5+
from ._typing import UniversalMiddlewareWithProcessResponse
76

7+
if TYPE_CHECKING:
8+
from .asgi.request import Request as AsgiRequest
9+
from .asgi.response import Response as AsgiResponse
10+
from .request import Request
11+
from .response import Response
812

9-
class CORSMiddleware(object):
13+
14+
class CORSMiddleware(UniversalMiddlewareWithProcessResponse):
1015
"""CORS Middleware.
1116
1217
This middleware provides a simple out-of-the box CORS policy, including handling
@@ -141,5 +146,11 @@ def process_response(
141146
resp.set_header('Access-Control-Allow-Headers', allow_headers)
142147
resp.set_header('Access-Control-Max-Age', '86400') # 24 hours
143148

144-
async def process_response_async(self, *args: Any) -> None:
145-
self.process_response(*args)
149+
async def process_response_async(
150+
self,
151+
req: AsgiRequest,
152+
resp: AsgiResponse,
153+
resource: object,
154+
req_succeeded: bool,
155+
) -> None:
156+
self.process_response(req, resp, resource, req_succeeded)

0 commit comments

Comments
 (0)