Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tornado/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,29 @@ def headers_received(
# TODO: either make context an official part of the
# HTTPConnection interface or figure out some other way to do this.
self.connection.context._apply_xheaders(headers) # type: ignore
start_line, headers = self.apply_forwarded_context(start_line, headers)

return self.delegate.headers_received(start_line, headers)

def apply_forwarded_context(
self,
start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
headers: httputil.HTTPHeaders,
) -> Tuple[
Union[httputil.RequestStartLine, httputil.ResponseStartLine],
httputil.HTTPHeaders,
]:
"""Apply X-Forwarded-Context header to requested uri"""
if isinstance(start_line, httputil.RequestStartLine):
# get path from X-Forwarded-Context
proxy_path = headers.get("X-Forwarded-Context", None)
if proxy_path:
# preserve only the path part
path = proxy_path.split("?", 1)[0]
start_line = start_line._replace(path=path)

return start_line, headers

def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
return self.delegate.data_received(chunk)

Expand Down
11 changes: 10 additions & 1 deletion tornado/test/httpserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def finish(self):

class HandlerBaseTestCase(AsyncHTTPTestCase):
def get_app(self):
return Application([("/", self.__class__.Handler)])
return Application([("/.*", self.__class__.Handler)])

def fetch_json(self, *args, **kwargs):
response = self.fetch(*args, **kwargs)
Expand Down Expand Up @@ -554,6 +554,7 @@ def get(self):
dict(
remote_ip=self.request.remote_ip,
remote_protocol=self.request.protocol,
path=self.request.path,
)
)

Expand Down Expand Up @@ -640,6 +641,14 @@ def test_scheme_headers(self):
self.fetch_json("/", headers=bad_forwarded)["remote_protocol"], "http"
)

def test_forwarded_context(self):
self.assertEqual(self.fetch_json("/")["path"], "/")

self.assertEqual(
self.fetch_json("/", headers={"X-Forwarded-Context": "/prefix"})["path"],
"/prefix",
)


class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase):
def get_app(self):
Expand Down