Skip to content
This repository was archived by the owner on Jan 13, 2021. It is now read-only.

Add ENABLE_PUSH flag in the Upgrade HTTP2-Settings header #310

Open
wants to merge 11 commits into
base: development
Choose a base branch
from
19 changes: 18 additions & 1 deletion hyper/common/connection.py
Original file line number Diff line number Diff line change
@@ -62,7 +62,8 @@ def __init__(self,
self._port = port
self._h1_kwargs = {
'secure': secure, 'ssl_context': ssl_context,
'proxy_host': proxy_host, 'proxy_port': proxy_port
'proxy_host': proxy_host, 'proxy_port': proxy_port,
'enable_push': enable_push
}
self._h2_kwargs = {
'window_manager': window_manager, 'enable_push': enable_push,
@@ -143,6 +144,22 @@ def get_response(self, *args, **kwargs):

return self._conn.get_response(1)

def get_pushes(self, *args, **kwargs):
try:
return self._conn.get_pushes(*args, **kwargs)
except HTTPUpgrade as e:
assert e.negotiated == H2C_PROTOCOL

self._conn = HTTP20Connection(
self._host, self._port, **self._h2_kwargs
)

self._conn._connect_upgrade(e.sock, True)
# stream id 1 is used by the upgrade request and response
# and is half-closed by the client

return self._conn.get_pushes(*args, **kwargs)

# The following two methods are the implementation of the context manager
# protocol.
def __enter__(self): # pragma: no cover
63 changes: 42 additions & 21 deletions hyper/http11/connection.py
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ class HTTP11Connection(object):
"""

version = HTTPVersion.http11
_response = None

def __init__(self, host, port=None, secure=None, ssl_context=None,
proxy_host=None, proxy_port=None, **kwargs):
@@ -78,6 +79,7 @@ def __init__(self, host, port=None, secure=None, ssl_context=None,

# only send http upgrade headers for non-secure connection
self._send_http_upgrade = not self.secure
self._enable_push = kwargs.get('enable_push')

self.ssl_context = ssl_context
self._sock = None
@@ -104,6 +106,12 @@ def __init__(self, host, port=None, secure=None, ssl_context=None,
#: the standard hyper parsing interface.
self.parser = Parser()

def get_pushes(self, stream_id=None, capture_all=False):
"""
Dummy method to trigger h2c upgrade.
"""
self._get_response()

def connect(self):
"""
Connect to the server specified when the object was created. This is a
@@ -188,6 +196,7 @@ def request(self, method, url, body=None, headers=None):
# Next, send the request body.
if body:
self._send_body(body, body_type)
self._response = None

return

@@ -198,31 +207,39 @@ def get_response(self):
This is an early beta, so the response object is pretty stupid. That's
ok, we'll fix it later.
"""
headers = HTTPHeaderMap()
resp = self._get_response()
self._response = None
return resp

response = None
while response is None:
# 'encourage' the socket to receive data.
self._sock.fill()
response = self.parser.parse_response(self._sock.buffer)
def _get_response(self):
if self._response is None:

for n, v in response.headers:
headers[n.tobytes()] = v.tobytes()
headers = HTTPHeaderMap()

self._sock.advance_buffer(response.consumed)
response = None
while response is None:
# 'encourage' the socket to receive data.
self._sock.fill()
response = self.parser.parse_response(self._sock.buffer)

if (response.status == 101 and
for n, v in response.headers:
headers[n.tobytes()] = v.tobytes()

self._sock.advance_buffer(response.consumed)

if (response.status == 101 and
b'upgrade' in headers['connection'] and
H2C_PROTOCOL.encode('utf-8') in headers['upgrade']):
raise HTTPUpgrade(H2C_PROTOCOL, self._sock)

return HTTP11Response(
response.status,
response.msg.tobytes(),
headers,
self._sock,
self
)
H2C_PROTOCOL.encode('utf-8') in headers['upgrade']):
raise HTTPUpgrade(H2C_PROTOCOL, self._sock)

self._response = HTTP11Response(
response.status,
response.msg.tobytes(),
headers,
self._sock,
self
)
return self._response

def _send_headers(self, method, url, headers):
"""
@@ -276,6 +293,10 @@ def _add_upgrade_headers(self, headers):
# Settings header.
http2_settings = SettingsFrame(0)
http2_settings.settings[SettingsFrame.INITIAL_WINDOW_SIZE] = 65535
if self._enable_push is not None:
http2_settings.settings[SettingsFrame.ENABLE_PUSH] = (
int(self._enable_push)
)
encoded_settings = base64.urlsafe_b64encode(
http2_settings.serialize_body()
)
@@ -348,7 +369,7 @@ def _send_file_like_obj(self, fobj):
Handles streaming a file-like object to the network.
"""
while True:
block = fobj.read(16*1024)
block = fobj.read(16 * 1024)
if not block:
break

11 changes: 9 additions & 2 deletions hyper/http20/connection.py
Original file line number Diff line number Diff line change
@@ -114,6 +114,7 @@ def __init__(self, host, port=None, secure=None, window_manager=None,
else:
self.secure = False

self._delay_recv = False
self._enable_push = enable_push
self.ssl_context = ssl_context

@@ -313,6 +314,9 @@ def get_response(self, stream_id=None):
get a response.
:returns: A :class:`HTTP20Response <hyper.HTTP20Response>` object.
"""
if self._delay_recv:
self._recv_cb()
self._delay_recv = False
stream = self._get_stream(stream_id)
return HTTP20Response(stream.getheaders(), stream)

@@ -384,7 +388,7 @@ def connect(self):

self._send_preamble()

def _connect_upgrade(self, sock):
def _connect_upgrade(self, sock, no_recv=False):
"""
Called by the generic HTTP connection when we're being upgraded. Locks
in a new socket and places the backing state machine into an upgrade
@@ -405,7 +409,10 @@ def _connect_upgrade(self, sock):
s = self._new_stream(local_closed=True)
self.recent_stream = s

self._recv_cb()
if no_recv: # To delay I/O operation
self._delay_recv = True
else:
self._recv_cb()

def _send_preamble(self):
"""
1 change: 1 addition & 0 deletions test/test_abstraction.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ def test_h1_kwargs(self):
'proxy_host': False,
'proxy_port': False,
'other_kwarg': True,
'enable_push': True,
}

def test_h2_kwargs(self):
33 changes: 28 additions & 5 deletions test/test_hyper.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
PingFrame, FRAME_MAX_ALLOWED_LEN
)
from hpack.hpack_compat import Encoder
from hyper import HTTPConnection
from hyper.http20.connection import HTTP20Connection
from hyper.http20.response import HTTP20Response, HTTP20Push
from hyper.http20.exceptions import ConnectionError, StreamResetError
@@ -731,8 +732,8 @@ def add_data_frame(self, stream_id, data, end_stream=False):
frame.flags.add('END_STREAM')
self.frames.append(frame)

def request(self):
self.conn = HTTP20Connection('www.google.com', enable_push=True)
def request(self, enable_push=True):
self.conn = HTTP20Connection('www.google.com', enable_push=enable_push)
self.conn._sock = DummySocket()
self.conn._sock.buffer = BytesIO(
b''.join([frame.serialize() for frame in self.frames])
@@ -934,13 +935,13 @@ def test_reset_pushed_streams_when_push_disabled(self):
1, [(':status', '200'), ('content-type', 'text/html')]
)

self.request()
self.conn._enable_push = False
self.request(enable_push=False)
self.conn.get_response()

f = RstStreamFrame(2)
f.error_code = 7
assert self.conn._sock.queue[-1] == f.serialize()
print(self.conn._sock.queue)
assert self.conn._sock.queue[-1].endswith(f.serialize())

def test_pushed_requests_ignore_unexpected_headers(self):
headers = HTTPHeaderMap([
@@ -956,7 +957,29 @@ def test_pushed_requests_ignore_unexpected_headers(self):
assert p.request_headers == HTTPHeaderMap([('no', 'no')])


class TestUpgradingPush(TestServerPush):
http101 = (b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: h2c\r\n"
b"\r\n")

def setup_method(self, method):
self.frames = [SettingsFrame(0)] # Server-side preface
self.encoder = Encoder()
self.conn = None

def request(self, enable_push=True):
self.conn = HTTPConnection('www.google.com', enable_push=enable_push)
self.conn._conn._sock = DummySocket()
self.conn._conn._sock.buffer = BytesIO(
self.http101 + b''.join([frame.serialize()
for frame in self.frames])
)
self.conn.request('GET', '/')


class TestResponse(object):

def test_status_is_stripped_from_headers(self):
headers = HTTPHeaderMap([(':status', '200')])
resp = HTTP20Response(headers, None)
119 changes: 82 additions & 37 deletions test/test_integration.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
import hyper
import hyper.http11.connection
import pytest
from contextlib import contextmanager
from mock import patch
from h2.frame_buffer import FrameBuffer
from hyper.compat import ssl
@@ -64,17 +65,30 @@ def frame_buffer():
return buffer


@contextmanager
def reusable_frame_buffer(buffer):
# FrameBuffer does not return new iterator for iteration.
data = buffer.data
yield buffer
buffer.data = data


def receive_preamble(sock):
# Receive the HTTP/2 'preamble'.
first = sock.recv(65535)
client_preface = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
timeout = time.time() + 5
got = b''
while len(got) < len(client_preface) and time.time() < timeout:
got += sock.recv(len(client_preface) - len(got))

assert got == client_preface, "client preface mismatch"

# Work around some bugs: if the first message received was only the PRI
# string, aim to receive a settings frame as well.
if len(first) <= len(b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'):
sock.recv(65535)
# Send server side HTTP/2 preface
sock.send(SettingsFrame(0).serialize())
sock.recv(65535)
return
# Drain to let the client proceed.
# Note that in the lower socket level, this method is not
# just doing "receive".
return sock.recv(65535)


@patch('hyper.http20.connection.H2_NPN_PROTOCOLS', PROTOCOLS)
@@ -138,7 +152,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
conn.connect()
send_event.wait()
send_event.wait(5)

# Get the chunk of data after the preamble and decode it into frames.
# We actually expect two, but only the second one contains ENABLE_PUSH.
@@ -242,7 +256,7 @@ def socket_handler(listener):
f = SettingsFrame(0)
sock.send(f.serialize())

send_event.wait()
send_event.wait(5)
sock.recv(65535)
sock.close()

@@ -260,6 +274,7 @@ def socket_handler(listener):
def test_closed_responses_remove_their_streams_from_conn(self):
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()

def socket_handler(listener):
@@ -270,6 +285,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send the headers for the response.
f = build_headers_frame([(':status', '200')])
f.stream_id = 1
@@ -282,6 +299,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
conn.request('GET', '/')
req_event.set()
resp = conn.get_response()

# Close the response.
@@ -296,6 +314,7 @@ def socket_handler(listener):
def test_receiving_responses_with_no_body(self):
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()

def socket_handler(listener):
@@ -306,6 +325,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send the headers for the response. This response has no body
f = build_headers_frame(
[(':status', '204'), ('content-length', '0')]
@@ -321,6 +342,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
conn.request('GET', '/')
req_event.set()
resp = conn.get_response()

# Confirm the status code.
@@ -338,6 +360,7 @@ def socket_handler(listener):
def test_receiving_trailers(self):
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()

def socket_handler(listener):
@@ -350,6 +373,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send the headers for the response.
f = build_headers_frame(
[(':status', '200'), ('content-length', '14')],
@@ -372,12 +397,13 @@ def socket_handler(listener):
sock.send(f.serialize())

# Wait for the message from the main thread.
recv_event.set()
recv_event.wait(5)
sock.close()

self._start_server(socket_handler)
conn = self.get_connection()
conn.request('GET', '/')
req_event.set()
resp = conn.get_response()

# Confirm the status code.
@@ -396,13 +422,14 @@ def socket_handler(listener):
assert len(resp.trailers) == 2

# Awesome, we're done now.
recv_event.wait(5)
recv_event.set()

self.tear_down()

def test_receiving_trailers_before_reading(self):
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()
wait_event = threading.Event()

@@ -416,6 +443,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send the headers for the response.
f = build_headers_frame(
[(':status', '200'), ('content-length', '14')],
@@ -449,6 +478,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
conn.request('GET', '/')
req_event.set()
resp = conn.get_response()

# Confirm the status code.
@@ -647,6 +677,7 @@ def test_resetting_stream_with_frames_in_flight(self):
"""
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()

def socket_handler(listener):
@@ -657,6 +688,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send the headers for the response. This response has no
# body.
f = build_headers_frame(
@@ -673,6 +706,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
stream_id = conn.request('GET', '/')
req_event.set()

# Now, trigger the RST_STREAM frame by closing the stream.
conn._send_rst_frame(stream_id, 0)
@@ -696,6 +730,7 @@ def test_stream_can_be_reset_multiple_times(self):
"""
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()

def socket_handler(listener):
@@ -706,6 +741,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send two RST_STREAM frames.
for _ in range(0, 2):
f = RstStreamFrame(1)
@@ -718,6 +755,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
conn.request('GET', '/')
req_event.set()

# Now, eat the Rst frames. These should not cause an exception.
conn._single_read()
@@ -737,6 +775,7 @@ def socket_handler(listener):
def test_read_chunked_http2(self):
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()
wait_event = threading.Event()

@@ -748,6 +787,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send the headers for the response. This response has a body.
f = build_headers_frame([(':status', '200')])
f.stream_id = 1
@@ -777,6 +818,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
conn.request('GET', '/')
req_event.set()
resp = conn.get_response()

# Confirm the status code.
@@ -805,6 +847,7 @@ def socket_handler(listener):
def test_read_delayed(self):
self.set_up()

req_event = threading.Event()
recv_event = threading.Event()
wait_event = threading.Event()

@@ -816,6 +859,8 @@ def socket_handler(listener):
receive_preamble(sock)
sock.recv(65535)

# Wait for request
req_event.wait(5)
# Now, send the headers for the response. This response has a body.
f = build_headers_frame([(':status', '200')])
f.stream_id = 1
@@ -845,6 +890,7 @@ def socket_handler(listener):
self._start_server(socket_handler)
conn = self.get_connection()
conn.request('GET', '/')
req_event.set()
resp = conn.get_response()

# Confirm the status code.
@@ -958,16 +1004,15 @@ def socket_handler(listener):

receive_preamble(sock)

# Wait for the message from the main thread.
send_event.wait()
# Send the headers for the response. This response has no body.
f = build_headers_frame(
[(':status', '200'), ('content-length', '0')]
)
f.flags.add('END_STREAM')
f.stream_id = 1
sock.sendall(f.serialize())

# Wait for the message from the main thread.
send_event.wait()
sock.close()

self._start_server(socket_handler)
@@ -996,7 +1041,7 @@ def socket_handler(listener):
data += sock.recv(65535)
assert b'upgrade: h2c\r\n' in data

send_event.wait()
send_event.wait(5)

# We need to send back a response.
resp = (
@@ -1038,7 +1083,7 @@ class TestRequestsAdapter(SocketLevelTest):
# This uses HTTP/2.
h2 = True

def test_adapter_received_values(self, monkeypatch):
def test_adapter_received_values(self, monkeypatch, frame_buffer):
self.set_up()

# We need to patch the ssl_wrap_socket method to ensure that we
@@ -1051,17 +1096,20 @@ def wrap(*args):

monkeypatch.setattr(hyper.http11.connection, 'wrap_socket', wrap)

data = []
send_event = threading.Event()

def socket_handler(listener):
sock = listener.accept()[0]

# Do the handshake: conn header, settings, send settings, recv ack.
receive_preamble(sock)
frame_buffer.add_data(receive_preamble(sock))

# Now expect some data. One headers frame.
data.append(sock.recv(65535))
req_wait = True
while req_wait:
frame_buffer.add_data(sock.recv(65535))
with reusable_frame_buffer(frame_buffer) as fr:
for f in fr:
if isinstance(f, HeadersFrame):
req_wait = False

# Respond!
h = HeadersFrame(1)
@@ -1078,8 +1126,6 @@ def socket_handler(listener):
d.data = b'1234567890' * 2
d.flags.add('END_STREAM')
sock.send(d.serialize())

send_event.wait(5)
sock.close()

self._start_server(socket_handler)
@@ -1093,11 +1139,9 @@ def socket_handler(listener):
assert r.headers[b'Content-Type'] == b'not/real'
assert r.content == b'1234567890' * 2

send_event.set()

self.tear_down()

def test_adapter_sending_values(self, monkeypatch):
def test_adapter_sending_values(self, monkeypatch, frame_buffer):
self.set_up()

# We need to patch the ssl_wrap_socket method to ensure that we
@@ -1110,17 +1154,20 @@ def wrap(*args):

monkeypatch.setattr(hyper.http11.connection, 'wrap_socket', wrap)

data = []

def socket_handler(listener):
sock = listener.accept()[0]

# Do the handshake: conn header, settings, send settings, recv ack.
receive_preamble(sock)
frame_buffer.add_data(receive_preamble(sock))

# Now expect some data. One headers frame and one data frame.
data.append(sock.recv(65535))
data.append(sock.recv(65535))
req_wait = True
while req_wait:
frame_buffer.add_data(sock.recv(65535))
with reusable_frame_buffer(frame_buffer) as fr:
for f in fr:
if isinstance(f, DataFrame):
req_wait = False

# Respond!
h = HeadersFrame(1)
@@ -1137,7 +1184,6 @@ def socket_handler(listener):
d.data = b'1234567890' * 2
d.flags.add('END_STREAM')
sock.send(d.serialize())

sock.close()

self._start_server(socket_handler)
@@ -1152,11 +1198,10 @@ def socket_handler(listener):
# Assert about the sent values.
assert r.status_code == 200

f = decode_frame(data[0])
assert isinstance(f, HeadersFrame)
frames = list(frame_buffer)
assert isinstance(frames[-2], HeadersFrame)

f = decode_frame(data[1])
assert isinstance(f, DataFrame)
assert f.data == b'hi there'
assert isinstance(frames[-1], DataFrame)
assert frames[-1].data == b'hi there'

self.tear_down()