@@ -39,67 +39,103 @@ <h1 class="title">Module <code>supertokens_python.framework.fastapi.fastapi_midd
3939# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
4040# License for the specific language governing permissions and limitations
4141# under the License.
42- from __future__ import annotations
42+ from typing import Union
4343
44- from typing import TYPE_CHECKING, Union
4544
46- from supertokens_python.framework import BaseResponse
45+ def get_middleware():
46+ from supertokens_python import Supertokens
47+ from supertokens_python.utils import default_user_context
48+ from supertokens_python.exceptions import SuperTokensError
49+ from supertokens_python.framework import BaseResponse
50+ from supertokens_python.recipe.session import SessionContainer
51+ from supertokens_python.supertokens import manage_session_post_response
4752
48- if TYPE_CHECKING:
49- from fastapi import Request
53+ from starlette.requests import Request
54+ from starlette.responses import Response
55+ from starlette.types import ASGIApp, Message, Receive, Scope, Send
5056
57+ from supertokens_python.framework.fastapi.fastapi_request import (
58+ FastApiRequest,
59+ )
60+ from supertokens_python.framework.fastapi.fastapi_response import (
61+ FastApiResponse,
62+ )
5163
52- def get_middleware() :
53- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
54- from supertokens_python.utils import default_user_context
64+ class ASGIMiddleware :
65+ def __init__(self, app: ASGIApp) -> None:
66+ self.app = app
5567
56- class Middleware(BaseHTTPMiddleware):
57- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
58- from supertokens_python import Supertokens
59- from supertokens_python.exceptions import SuperTokensError
60- from supertokens_python.framework.fastapi.fastapi_request import (
61- FastApiRequest,
62- )
63- from supertokens_python.framework.fastapi.fastapi_response import (
64- FastApiResponse,
65- )
66- from supertokens_python.recipe.session import SessionContainer
67- from supertokens_python.supertokens import manage_session_post_response
68+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
69+ if scope["type"] != "http": # we pass through the non-http requests, if any
70+ await self.app(scope, receive, send)
71+ return
6872
6973 st = Supertokens.get_instance()
70- from fastapi.responses import Response
7174
75+ request = Request(scope, receive=receive)
7276 custom_request = FastApiRequest(request)
73- response = FastApiResponse(Response())
7477 user_context = default_user_context(custom_request)
7578
7679 try:
80+ response = FastApiResponse(Response())
7781 result: Union[BaseResponse, None] = await st.middleware(
7882 custom_request, response, user_context
7983 )
8084 if result is None:
81- response = await call_next(request)
82- result = FastApiResponse(response)
85+ # This means that the supertokens middleware did not handle the request,
86+ # however, we may need to handle the header changes in the response,
87+ # based on response mutators used by the session.
88+ async def send_wrapper(message: Message):
89+ if message["type"] == "http.response.start":
90+ # Start message has the headers, so we update the headers here
91+ # by using `manage_session_post_response` function, which will
92+ # apply all the Response Mutators. In the end, we just replace
93+ # the updated headers in the message.
94+ if hasattr(request.state, "supertokens") and isinstance(
95+ request.state.supertokens, SessionContainer
96+ ):
97+ fapi_response = Response()
98+ fapi_response.raw_headers = message["headers"]
99+ response = FastApiResponse(fapi_response)
100+ manage_session_post_response(
101+ request.state.supertokens, response, user_context
102+ )
103+ message["headers"] = fapi_response.raw_headers
83104
105+ # For `http.response.start` message, we might have the headers updated,
106+ # otherwise, we just send all the messages as is
107+ await send(message)
108+
109+ await self.app(scope, receive, send_wrapper)
110+ return
111+
112+ # This means that the request was handled by the supertokens middleware
113+ # and hence we respond using the response object returned by the middleware.
84114 if hasattr(request.state, "supertokens") and isinstance(
85115 request.state.supertokens, SessionContainer
86116 ):
87117 manage_session_post_response(
88118 request.state.supertokens, result, user_context
89119 )
120+
90121 if isinstance(result, FastApiResponse):
91- return result.response
122+ await result.response(scope, receive, send)
123+ return
124+
125+ return
126+
92127 except SuperTokensError as e:
93128 response = FastApiResponse(Response())
94129 result: Union[BaseResponse, None] = await st.handle_supertokens_error(
95130 FastApiRequest(request), e, response, user_context
96131 )
97132 if isinstance(result, FastApiResponse):
98- return result.response
133+ await result.response(scope, receive, send)
134+ return
99135
100136 raise Exception("Should never come here")
101137
102- return Middleware </ code > </ pre >
138+ return ASGIMiddleware </ code > </ pre >
103139</ details >
104140</ section >
105141< section >
@@ -119,56 +155,99 @@ <h2 class="section-title" id="header-functions">Functions</h2>
119155< span > Expand source code</ span >
120156</ summary >
121157< pre > < code class ="python "> def get_middleware():
122- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
158+ from supertokens_python import Supertokens
123159 from supertokens_python.utils import default_user_context
160+ from supertokens_python.exceptions import SuperTokensError
161+ from supertokens_python.framework import BaseResponse
162+ from supertokens_python.recipe.session import SessionContainer
163+ from supertokens_python.supertokens import manage_session_post_response
124164
125- class Middleware(BaseHTTPMiddleware):
126- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
127- from supertokens_python import Supertokens
128- from supertokens_python.exceptions import SuperTokensError
129- from supertokens_python.framework.fastapi.fastapi_request import (
130- FastApiRequest,
131- )
132- from supertokens_python.framework.fastapi.fastapi_response import (
133- FastApiResponse,
134- )
135- from supertokens_python.recipe.session import SessionContainer
136- from supertokens_python.supertokens import manage_session_post_response
165+ from starlette.requests import Request
166+ from starlette.responses import Response
167+ from starlette.types import ASGIApp, Message, Receive, Scope, Send
168+
169+ from supertokens_python.framework.fastapi.fastapi_request import (
170+ FastApiRequest,
171+ )
172+ from supertokens_python.framework.fastapi.fastapi_response import (
173+ FastApiResponse,
174+ )
175+
176+ class ASGIMiddleware:
177+ def __init__(self, app: ASGIApp) -> None:
178+ self.app = app
179+
180+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
181+ if scope["type"] != "http": # we pass through the non-http requests, if any
182+ await self.app(scope, receive, send)
183+ return
137184
138185 st = Supertokens.get_instance()
139- from fastapi.responses import Response
140186
187+ request = Request(scope, receive=receive)
141188 custom_request = FastApiRequest(request)
142- response = FastApiResponse(Response())
143189 user_context = default_user_context(custom_request)
144190
145191 try:
192+ response = FastApiResponse(Response())
146193 result: Union[BaseResponse, None] = await st.middleware(
147194 custom_request, response, user_context
148195 )
149196 if result is None:
150- response = await call_next(request)
151- result = FastApiResponse(response)
197+ # This means that the supertokens middleware did not handle the request,
198+ # however, we may need to handle the header changes in the response,
199+ # based on response mutators used by the session.
200+ async def send_wrapper(message: Message):
201+ if message["type"] == "http.response.start":
202+ # Start message has the headers, so we update the headers here
203+ # by using `manage_session_post_response` function, which will
204+ # apply all the Response Mutators. In the end, we just replace
205+ # the updated headers in the message.
206+ if hasattr(request.state, "supertokens") and isinstance(
207+ request.state.supertokens, SessionContainer
208+ ):
209+ fapi_response = Response()
210+ fapi_response.raw_headers = message["headers"]
211+ response = FastApiResponse(fapi_response)
212+ manage_session_post_response(
213+ request.state.supertokens, response, user_context
214+ )
215+ message["headers"] = fapi_response.raw_headers
216+
217+ # For `http.response.start` message, we might have the headers updated,
218+ # otherwise, we just send all the messages as is
219+ await send(message)
152220
221+ await self.app(scope, receive, send_wrapper)
222+ return
223+
224+ # This means that the request was handled by the supertokens middleware
225+ # and hence we respond using the response object returned by the middleware.
153226 if hasattr(request.state, "supertokens") and isinstance(
154227 request.state.supertokens, SessionContainer
155228 ):
156229 manage_session_post_response(
157230 request.state.supertokens, result, user_context
158231 )
232+
159233 if isinstance(result, FastApiResponse):
160- return result.response
234+ await result.response(scope, receive, send)
235+ return
236+
237+ return
238+
161239 except SuperTokensError as e:
162240 response = FastApiResponse(Response())
163241 result: Union[BaseResponse, None] = await st.handle_supertokens_error(
164242 FastApiRequest(request), e, response, user_context
165243 )
166244 if isinstance(result, FastApiResponse):
167- return result.response
245+ await result.response(scope, receive, send)
246+ return
168247
169248 raise Exception("Should never come here")
170249
171- return Middleware </ code > </ pre >
250+ return ASGIMiddleware </ code > </ pre >
172251</ details >
173252</ dd >
174253</ dl >
0 commit comments