diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 8044a4f828..34ae0de62a 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -378,8 +378,29 @@ def headers_received( # TODO: either make context an official part of the # HTTPConnection interface or figure out some other way to do this. self.connection.context._apply_xheaders(headers) # type: ignore + start_line, headers = self.apply_forwarded_context(start_line, headers) + return self.delegate.headers_received(start_line, headers) + def apply_forwarded_context( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Tuple[ + Union[httputil.RequestStartLine, httputil.ResponseStartLine], + httputil.HTTPHeaders, + ]: + """Apply X-Forwarded-Context header to requested uri""" + if isinstance(start_line, httputil.RequestStartLine): + # get path from X-Forwarded-Context + proxy_path = headers.get("X-Forwarded-Context", None) + if proxy_path: + # preserve only the path part + path = proxy_path.split("?", 1)[0] + start_line = start_line._replace(path=path) + + return start_line, headers + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: return self.delegate.data_received(chunk) diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 2a8b6a5b14..64deae848f 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -73,7 +73,7 @@ def finish(self): class HandlerBaseTestCase(AsyncHTTPTestCase): def get_app(self): - return Application([("/", self.__class__.Handler)]) + return Application([("/.*", self.__class__.Handler)]) def fetch_json(self, *args, **kwargs): response = self.fetch(*args, **kwargs) @@ -554,6 +554,7 @@ def get(self): dict( remote_ip=self.request.remote_ip, remote_protocol=self.request.protocol, + path=self.request.path, ) ) @@ -640,6 +641,14 @@ def test_scheme_headers(self): self.fetch_json("/", headers=bad_forwarded)["remote_protocol"], "http" ) + def test_forwarded_context(self): + self.assertEqual(self.fetch_json("/")["path"], "/") + + self.assertEqual( + self.fetch_json("/", headers={"X-Forwarded-Context": "/prefix"})["path"], + "/prefix", + ) + class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase): def get_app(self):