diff --git a/h2/config.py b/h2/config.py index 08129a406..47ab27a6d 100644 --- a/h2/config.py +++ b/h2/config.py @@ -6,6 +6,8 @@ Objects for controlling the configuration of the HTTP/2 stack. """ +import logging + class _BooleanConfigOption(object): """ @@ -34,7 +36,17 @@ class DummyLogger(object): logging functions when no logger is passed into the corresponding object. """ def __init__(self, *vargs): - pass + # Disable all logging + self.lvl = logging.CRITICAL + 1 + + def isEnabledFor(self, lvl): + """ + Dummy logger, so nothing is enabled. + """ + return lvl >= self.lvl + + def setLevel(self, lvl): + self.lvl = lvl def debug(self, *vargs, **kwargs): """ diff --git a/h2/connection.py b/h2/connection.py index 20975e3bb..822b9b256 100644 --- a/h2/connection.py +++ b/h2/connection.py @@ -6,6 +6,7 @@ An implementation of a HTTP/2 connection. """ import base64 +import logging from enum import Enum, IntEnum @@ -292,6 +293,7 @@ def __init__(self, config=None): self.encoder = Encoder() self.decoder = Decoder() + self._open_stream_counts = {0: 0, 1: 0} # This won't always actually do anything: for versions of HPACK older # than 2.3.0 it does nothing. However, we have to try! self.decoder.max_header_list_size = self.DEFAULT_MAX_HEADER_LIST_SIZE @@ -362,6 +364,8 @@ def __init__(self, config=None): size_limit=self.MAX_CLOSED_STREAMS ) + self._streams_to_close = list() + # The flow control window manager for the connection. self._inbound_flow_control_window_manager = WindowManager( max_window_size=self.local_settings.initial_window_size @@ -383,6 +387,13 @@ def __init__(self, config=None): ExtensionFrame: self._receive_unknown_frame } + def _increment_open_streams(self, stream_id, incr): + remainder = stream_id % 2 + self._open_stream_counts[remainder] += incr + + def _close_stream(self, stream_id): + self._streams_to_close.append(stream_id) + def _prepare_for_sending(self, frames): if not frames: return @@ -393,22 +404,15 @@ def _open_streams(self, remainder): """ A common method of counting number of open streams. Returns the number of streams that are open *and* that have (stream ID % 2) == remainder. - While it iterates, also deletes any closed streams. + Also cleans up closed streams. """ - count = 0 - to_delete = [] - - for stream_id, stream in self.streams.items(): - if stream.open and (stream_id % 2 == remainder): - count += 1 - elif stream.closed: - to_delete.append(stream_id) - - for stream_id in to_delete: + for stream_id in self._streams_to_close: stream = self.streams.pop(stream_id) + assert stream.closed self._closed_streams[stream_id] = stream.closed_by + self._streams_to_close = list() - return count + return self._open_stream_counts[remainder] @property def open_outbound_streams(self): @@ -467,14 +471,20 @@ def _begin_new_stream(self, stream_id, allowed_ids): stream_id, config=self.config, inbound_window_size=self.local_settings.initial_window_size, - outbound_window_size=self.remote_settings.initial_window_size + outbound_window_size=self.remote_settings.initial_window_size, + increment_open_stream_count_callback=self._increment_open_streams, + close_stream_callback=self._close_stream, ) self.config.logger.debug("Stream ID %d created", stream_id) s.max_inbound_frame_size = self.max_inbound_frame_size s.max_outbound_frame_size = self.max_outbound_frame_size self.streams[stream_id] = s - self.config.logger.debug("Current streams: %s", self.streams.keys()) + # Disable this log if we're not in debug mode, as it can be expensive + # when there are many concurrently open streams + if self.config.logger.isEnabledFor(logging.DEBUG): + self.config.logger.debug( + "Current streams: %s", self.streams.keys()) if outbound: self.highest_outbound_stream_id = stream_id @@ -1025,7 +1035,6 @@ def reset_stream(self, stream_id, error_code=0): def close_connection(self, error_code=0, additional_data=None, last_stream_id=None): - """ Close a connection, emitting a GOAWAY frame. @@ -1542,8 +1551,8 @@ def _receive_headers_frame(self, frame): max_open_streams = self.local_settings.max_concurrent_streams if (self.open_inbound_streams + 1) > max_open_streams: raise TooManyStreamsError( - "Max outbound streams is %d, %d open" % - (max_open_streams, self.open_outbound_streams) + "Max inbound streams is %d, %d open" % + (max_open_streams, self.open_inbound_streams) ) # Let's decode the headers. We handle headers as bytes internally up diff --git a/h2/stream.py b/h2/stream.py index 827e65a71..270b10160 100644 --- a/h2/stream.py +++ b/h2/stream.py @@ -90,6 +90,7 @@ class H2StreamStateMachine(object): :param stream_id: The stream ID of this stream. This is stored primarily for logging purposes. """ + def __init__(self, stream_id): self.state = StreamState.IDLE self.stream_id = stream_id @@ -767,6 +768,55 @@ def send_alt_svc(self, previous_state): (H2StreamStateMachine.send_on_closed_stream, StreamState.CLOSED), } +""" +Wraps a stream state change function to ensure that we keep +the parent H2Connection's state in sync +""" + + +def sync_state_change(func): + def wrapper(self, *args, **kwargs): + # Collect state at the beginning. + start_state = self.state_machine.state + started_open = self.open + started_closed = not started_open + + # Do the state change (if any). + result = func(self, *args, **kwargs) + + # Collect state at the end. + end_state = self.state_machine.state + ended_open = self.open + ended_closed = not ended_open + + # If at any point we've tranwsitioned to the CLOSED state + # from any other state, close our stream. + if end_state == StreamState.CLOSED and start_state != end_state: + if self._close_stream_callback: + self._close_stream_callback(self.stream_id) + # Clear callback so we only call this once per stream + self._close_stream_callback = None + + # If we were open, but are now closed, decrement + # the open stream count, and call the close callback. + if started_open and ended_closed: + if self._decrement_open_stream_count_callback: + self._decrement_open_stream_count_callback(self.stream_id, + -1,) + # Clear callback so we only call this once per stream + self._decrement_open_stream_count_callback = None + + # If we were closed, but are now open, increment + # the open stream count. + elif started_closed and ended_open: + if self._increment_open_stream_count_callback: + self._increment_open_stream_count_callback(self.stream_id, + 1,) + # Clear callback so we only call this once per stream + self._increment_open_stream_count_callback = None + return result + return wrapper + class H2Stream(object): """ @@ -778,22 +828,36 @@ class H2Stream(object): Attempts to create frames that cannot be sent will raise a ``ProtocolError``. """ + def __init__(self, stream_id, config, inbound_window_size, - outbound_window_size): + outbound_window_size, + increment_open_stream_count_callback, + close_stream_callback,): self.state_machine = H2StreamStateMachine(stream_id) self.stream_id = stream_id self.max_outbound_frame_size = None self.request_method = None - # The current value of the outbound stream flow control window + # The current value of the outbound stream flow control window. self.outbound_flow_control_window = outbound_window_size # The flow control manager. self._inbound_window_manager = WindowManager(inbound_window_size) + # Callback to increment open stream count for the H2Connection. + self._increment_open_stream_count_callback = \ + increment_open_stream_count_callback + + # Callback to decrement open stream count for the H2Connection. + self._decrement_open_stream_count_callback = \ + increment_open_stream_count_callback + + # Callback to clean up state for the H2Connection once we're closed. + self._close_stream_callback = close_stream_callback + # The expected content length, if any. self._expected_content_length = None @@ -850,6 +914,7 @@ def closed_by(self): """ return self.state_machine.stream_closed_by + @sync_state_change def upgrade(self, client_side): """ Called by the connection to indicate that this stream is the initial @@ -868,6 +933,7 @@ def upgrade(self, client_side): self.state_machine.process_input(input_) return + @sync_state_change def send_headers(self, headers, encoder, end_stream=False): """ Returns a list of HEADERS/CONTINUATION frames to emit as either headers @@ -917,6 +983,7 @@ def send_headers(self, headers, encoder, end_stream=False): return frames + @sync_state_change def push_stream_in_band(self, related_stream_id, headers, encoder): """ Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed @@ -941,6 +1008,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder): return frames + @sync_state_change def locally_pushed(self): """ Mark this stream as one that was pushed by this peer. Must be called @@ -954,6 +1022,7 @@ def locally_pushed(self): assert not events return [] + @sync_state_change def send_data(self, data, end_stream=False, pad_length=None): """ Prepare some data frames. Optionally end the stream. @@ -981,6 +1050,7 @@ def send_data(self, data, end_stream=False, pad_length=None): return [df] + @sync_state_change def end_stream(self): """ End a stream without sending data. @@ -992,6 +1062,7 @@ def end_stream(self): df.flags.add('END_STREAM') return [df] + @sync_state_change def advertise_alternative_service(self, field_value): """ Advertise an RFC 7838 alternative service. The semantics of this are @@ -1005,6 +1076,7 @@ def advertise_alternative_service(self, field_value): asf.field = field_value return [asf] + @sync_state_change def increase_flow_control_window(self, increment): """ Increase the size of the flow control window for the remote side. @@ -1020,6 +1092,7 @@ def increase_flow_control_window(self, increment): wuf.window_increment = increment return [wuf] + @sync_state_change def receive_push_promise_in_band(self, promised_stream_id, headers, @@ -1044,6 +1117,7 @@ def receive_push_promise_in_band(self, ) return [], events + @sync_state_change def remotely_pushed(self, pushed_headers): """ Mark this stream as one that was pushed by the remote peer. Must be @@ -1057,6 +1131,7 @@ def remotely_pushed(self, pushed_headers): self._authority = authority_from_headers(pushed_headers) return [], events + @sync_state_change def receive_headers(self, headers, end_stream, header_encoding): """ Receive a set of headers (or trailers). @@ -1091,6 +1166,7 @@ def receive_headers(self, headers, end_stream, header_encoding): ) return [], events + @sync_state_change def receive_data(self, data, end_stream, flow_control_len): """ Receive some data. @@ -1114,6 +1190,7 @@ def receive_data(self, data, end_stream, flow_control_len): events[0].flow_controlled_length = flow_control_len return [], events + @sync_state_change def receive_window_update(self, increment): """ Handle a WINDOW_UPDATE increment. @@ -1150,6 +1227,7 @@ def receive_window_update(self, increment): return frames, events + @sync_state_change def receive_continuation(self): """ A naked CONTINUATION frame has been received. This is always an error, @@ -1162,6 +1240,7 @@ def receive_continuation(self): ) assert False, "Should not be reachable" + @sync_state_change def receive_alt_svc(self, frame): """ An Alternative Service frame was received on the stream. This frame @@ -1189,6 +1268,7 @@ def receive_alt_svc(self, frame): return [], events + @sync_state_change def reset_stream(self, error_code=0): """ Close the stream locally. Reset the stream with an error code. @@ -1202,6 +1282,7 @@ def reset_stream(self, error_code=0): rsf.error_code = error_code return [rsf] + @sync_state_change def stream_reset(self, frame): """ Handle a stream being reset remotely. @@ -1217,6 +1298,7 @@ def stream_reset(self, frame): return [], events + @sync_state_change def acknowledge_received_data(self, acknowledged_size): """ The user has informed us that they've processed some amount of data diff --git a/test/test_basic_logic.py b/test/test_basic_logic.py index 7df99a6a5..adecad2e4 100644 --- a/test/test_basic_logic.py +++ b/test/test_basic_logic.py @@ -1851,7 +1851,7 @@ def test_stream_repr(self): """ Ensure stream string representation is appropriate. """ - s = h2.stream.H2Stream(4, None, 12, 14) + s = h2.stream.H2Stream(4, None, 12, 14, None, None) assert repr(s) == ">" diff --git a/test/test_concurrent_stream_open.py b/test/test_concurrent_stream_open.py new file mode 100644 index 000000000..d95a7287c --- /dev/null +++ b/test/test_concurrent_stream_open.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" +test_flow_control +~~~~~~~~~~~~~~~~~ + +Tests of the flow control management in h2 +""" +import logging +import time + + +import h2.config +import h2.connection +import h2.errors +import h2.events +import h2.exceptions +import h2.settings +from h2.stream import H2Stream, sync_state_change + + +class TestConcurrentStreamOpen(object): + """ + Tests the performance of concurrently opening streams + """ + example_request_headers = [ + (':authority', 'example.com'), + (':path', '/'), + (':scheme', 'https'), + (':method', 'GET'), + ] + server_config = h2.config.H2Configuration(client_side=False) + client_config = h2.config.H2Configuration(client_side=True) + + DEFAULT_FLOW_WINDOW = 65535 + + def test_sync_state_change_incr_conditional(self, frame_factory): + + @sync_state_change + def wrap_send_headers(self, *args, **kwargs): + return self.send_headers(*args, **kwargs) + + def dummy_callback(*args, **kwargs): + pass + + c = h2.connection.H2Connection() + s = H2Stream(1, self.client_config, self.DEFAULT_FLOW_WINDOW, + self.DEFAULT_FLOW_WINDOW, dummy_callback, + dummy_callback) + s.max_outbound_frame_size = 65536 + + wrap_send_headers(s, self.example_request_headers, + c.encoder, end_stream=False) + assert s.open + + def test_concurrent_stream_open_performance(self, frame_factory): + """ + Opening many concurrent streams isn't prohibitively expensive + """ + num_concurrent_streams = 10000 + + c = h2.connection.H2Connection() + c.initiate_connection() + start = time.time() + for i in range(num_concurrent_streams): + c.send_headers( + 1 + (2 * i), self.example_request_headers, end_stream=False) + c.clear_outbound_data_buffer() + end = time.time() + + run_time = end - start + assert run_time < 5 + + def test_stream_open_with_debug_logging(self, frame_factory): + """ + Test that opening a stream with debug logging works + """ + c = h2.connection.H2Connection() + c.initiate_connection() + c.config.logger.setLevel(logging.DEBUG) + c.send_headers( + 1, self.example_request_headers, end_stream=False) + c.clear_outbound_data_buffer()