From a32cac862cf746186c2a451919d6f01b4d1111c0 Mon Sep 17 00:00:00 2001 From: Javad Asgari Shafique Date: Wed, 23 Oct 2024 03:34:49 +0200 Subject: [PATCH] Address https://github.com/python/cpython/issues/118950 in uvloop by porting fix and adding tests to ensure asyncio.streams code effectively can schedule connection_lost and raise ConnectionResetError --- tests/test_aiohttp.py | 61 ++++++++++++++++++++++++++++++++++++++++++- tests/test_tcp.py | 52 ++++++++++++++++++++++++++++++++++++ uvloop/sslproto.pyx | 5 +++- 3 files changed, 116 insertions(+), 2 deletions(-) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 514d0177..b67e10e4 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -7,6 +7,7 @@ skip_tests = False import asyncio +import os import sys import unittest import weakref @@ -14,7 +15,7 @@ from uvloop import _testbase as tb -class _TestAioHTTP: +class _TestAioHTTP(tb.SSLTestCase): def test_aiohttp_basic_1(self): @@ -115,6 +116,64 @@ async def stop(): self.loop.run_until_complete(stop()) + def test_aiohttp_connection_lost_when_busy(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('bug in asyncio #118950 tests in CPython.') + + cert = tb._cert_fullname(__file__, 'ssl_cert.pem') + key = tb._cert_fullname(__file__, 'ssl_key.pem') + ssl_context = self._create_server_ssl_context(cert, key) + client_ssl_context = self._create_client_ssl_context() + + asyncio.set_event_loop(self.loop) + app = aiohttp.web.Application() + + async def handler(request): + ws = aiohttp.web.WebSocketResponse() + await ws.prepare(request) + async for msg in ws: + print("Received:", msg.data) + return ws + + app.router.add_get('/', handler) + + runner = aiohttp.web.AppRunner(app) + self.loop.run_until_complete(runner.setup()) + host = '0.0.0.0' + site = aiohttp.web.TCPSite(runner, host, '0', ssl_context=ssl_context) + self.loop.run_until_complete(site.start()) + port = site._server.sockets[0].getsockname()[1] + session = aiohttp.ClientSession(loop=self.loop) + + async def test(): + async with session.ws_connect( + f"wss://{host}:{port}/", + ssl=client_ssl_context + ) as ws: + transport = ws._writer.transport + s = transport.get_extra_info('socket') + + if self.implementation == 'asyncio': + s._sock.close() + else: + os.close(s.fileno()) + + # FLOW_CONTROL_HIGH_WATER * 1024 + bytes_to_send = 64 * 1024 + iterations = 10 + msg = b'Hello world, still there?' + + # Send enough messages to trigger a socket write + one extra + for _ in range(iterations + 1): + await ws.send_bytes( + msg * ((bytes_to_send // len(msg)) // iterations)) + + self.assertRaises( + ConnectionResetError, self.loop.run_until_complete, test()) + + self.loop.run_until_complete(session.close()) + self.loop.run_until_complete(runner.cleanup()) + @unittest.skipIf(skip_tests, "no aiohttp module") class Test_UV_AioHTTP(_TestAioHTTP, tb.UVTestCase): diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 8759383d..aabb45fc 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -1,5 +1,6 @@ import asyncio import asyncio.sslproto +import contextlib import gc import os import select @@ -3192,6 +3193,57 @@ async def run_main(): self.loop.run_until_complete(run_main()) + def test_connection_lost_when_busy(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('bug in asyncio #118950 tests in CPython.') + + ssl_context = self._create_server_ssl_context( + self.ONLYCERT, self.ONLYKEY) + client_ssl_context = self._create_client_ssl_context() + port = tb.find_free_port() + + @contextlib.asynccontextmanager + async def server(): + async def client_handler(reader, writer): + ... + + srv = await asyncio.start_server( + client_handler, '0.0.0.0', + port, ssl=ssl_context, reuse_port=True) + + try: + yield + finally: + srv.close() + + async def client(): + reader, writer = await asyncio.open_connection( + '0.0.0.0', port, ssl=client_ssl_context) + transport = writer.transport + s = transport.get_extra_info('socket') + + if self.implementation == 'asyncio': + s._sock.close() + else: + os.close(s.fileno()) + + # FLOW_CONTROL_HIGH_WATER * 1024 + bytes_to_send = 64 * 1024 + iterations = 10 + msg = b'An really important message :)' + + # Busy drain loop + for _ in range(iterations + 1): + writer.write(msg * ((bytes_to_send // len(msg)) // iterations)) + await writer.drain() + + async def test(): + async with server(): + await client() + + self.assertRaises( + ConnectionResetError, self.loop.run_until_complete, test()) + class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase): pass diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index 42bb7644..fd37738a 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -37,7 +37,7 @@ cdef class _SSLProtocolTransport: return self._ssl_protocol._app_protocol def is_closing(self): - return self._closed + return self._closed or self._ssl_protocol._is_transport_closing() def close(self): """Close the transport. @@ -316,6 +316,9 @@ cdef class SSLProtocol: self._app_transport_created = True return self._app_transport + def _is_transport_closing(self): + return self._transport is not None and self._transport.is_closing() + def connection_made(self, transport): """Called when the low-level connection is made.