From e0f6d5e62efe0e04fb68c1d5e547deee9457ee92 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 4 Nov 2020 18:17:34 -0800 Subject: [PATCH 1/3] Better handle errors by retrying * Handle send failures * Handle send exception from ESP32SPI * Handle failed response read * Handle connect() failures * Handle runtime error from ESP32SPI on connect --- adafruit_requests.py | 95 +++++++++++++++++++++++++++----------------- 1 file changed, 59 insertions(+), 36 deletions(-) mode change 100755 => 100644 adafruit_requests.py diff --git a/adafruit_requests.py b/adafruit_requests.py old mode 100755 new mode 100644 index fc41a8b..722884e --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -53,6 +53,7 @@ __version__ = "0.0.0-auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git" +import errno class _RawResponse: def __init__(self, response): @@ -72,6 +73,8 @@ def readinto(self, buf): into buf.""" return self._response._readinto(buf) # pylint: disable=protected-access +class _SendFailed(Exception): + """Custom exception to abort sending a request.""" class Response: """The response from a request, contains all the headers/content""" @@ -94,11 +97,13 @@ def __init__(self, sock, session=None): self._chunked = False self._backwards_compatible = not hasattr(sock, "recv_into") - if self._backwards_compatible: - print("Socket missing recv_into. Using more memory to be compatible") http = self._readto(b" ") if not http: + if session: + session._close_socket(self.socket) + else: + self.socket.close() raise RuntimeError("Unable to read HTTP response.") self.status_code = int(bytes(self._readto(b" "))) self.reason = self._readto(b"\r\n") @@ -414,30 +419,39 @@ def _get_socket(self, host, port, proto, *, timeout=1): addr_info = self._socket_pool.getaddrinfo( host, port, 0, self._socket_pool.SOCK_STREAM )[0] - sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) - connect_host = addr_info[-1][0] - if proto == "https:": - sock = self._ssl_context.wrap_socket(sock, server_hostname=host) - connect_host = host - sock.settimeout(timeout) # socket read timeout - ok = True - try: - ok = sock.connect((connect_host, port)) - except MemoryError: - if not any(self._socket_free.items()): - raise - ok = False - - # We couldn't connect due to memory so clean up the open sockets. - if not ok: - self._free_sockets() - # Recreate the socket because the ESP-IDF won't retry the connection if it failed once. - sock = None # Clear first so the first socket can be cleaned up. - sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + retry_count = 0 + sock = None + while retry_count < 5 and sock is None: + if retry_count > 0: + if any(self._socket_free.items()): + self._free_sockets() + else: + raise RuntimeError("Out of sockets") + retry_count += 1 + + try: + sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + except OSError: + continue + + connect_host = addr_info[-1][0] if proto == "https:": sock = self._ssl_context.wrap_socket(sock, server_hostname=host) + connect_host = host sock.settimeout(timeout) # socket read timeout - sock.connect((connect_host, port)) + + try: + sock.connect((connect_host, port)) + except MemoryError: + sock.close() + sock = None + except OSError as e: + sock.close() + sock = None + + if sock is None: + raise RuntimeError("Repeated socket failures") + self._open_sockets[key] = sock self._socket_free[sock] = False return sock @@ -446,11 +460,15 @@ def _get_socket(self, host, port, proto, *, timeout=1): def _send(socket, data): total_sent = 0 while total_sent < len(data): - sent = socket.send(data[total_sent:]) + # ESP32SPI sockets raise a RuntimeError when unable to send. + try: + sent = socket.send(data[total_sent:]) + except RuntimeError: + sent = 0 if sent is None: sent = len(data) if sent == 0: - raise RuntimeError("Connection closed") + raise _SendFailed() total_sent += sent def _send_request(self, socket, host, method, path, headers, data, json): @@ -532,12 +550,19 @@ def request( self._last_response.close() self._last_response = None - socket = self._get_socket(host, port, proto, timeout=timeout) - try: - self._send_request(socket, host, method, path, headers, data, json) - except: - self._close_socket(socket) - raise + # We may fail to send the request if the socket we got is closed already. So, try a second + # time in that case. + retry_count = 0 + while retry_count < 2: + retry_count += 1 + socket = self._get_socket(host, port, proto, timeout=timeout) + try: + self._send_request(socket, host, method, path, headers, data, json) + break + except _SendFailed: + self._close_socket(socket) + if retry_count > 1: + raise resp = Response(socket, self) # our response if "location" in resp.headers and 300 <= resp.status_code <= 399: @@ -588,11 +613,9 @@ def __init__(self, socket, tls_mode): def connect(self, address): """connect wrapper to add non-standard mode parameter""" try: - self._socket.connect(address, self._mode) - return True - except RuntimeError: - return False - + return self._socket.connect(address, self._mode) + except RuntimeError as e: + raise OSError(errno.ENOMEM) class _FakeSSLContext: def __init__(self, iface): From 019d1053f33dab81320cf8d56a444f7b901542b7 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Wed, 4 Nov 2020 18:30:17 -0800 Subject: [PATCH 2/3] Black --- adafruit_requests.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/adafruit_requests.py b/adafruit_requests.py index 722884e..f747b5d 100644 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -55,6 +55,7 @@ import errno + class _RawResponse: def __init__(self, response): self._response = response @@ -73,9 +74,11 @@ def readinto(self, buf): into buf.""" return self._response._readinto(buf) # pylint: disable=protected-access + class _SendFailed(Exception): """Custom exception to abort sending a request.""" + class Response: """The response from a request, contains all the headers/content""" @@ -430,7 +433,9 @@ def _get_socket(self, host, port, proto, *, timeout=1): retry_count += 1 try: - sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2]) + sock = self._socket_pool.socket( + addr_info[0], addr_info[1], addr_info[2] + ) except OSError: continue @@ -617,6 +622,7 @@ def connect(self, address): except RuntimeError as e: raise OSError(errno.ENOMEM) + class _FakeSSLContext: def __init__(self, iface): self._iface = iface From 7f2877e6f4d981967ed982b4600dbc57fd622230 Mon Sep 17 00:00:00 2001 From: Scott Shawcroft Date: Thu, 5 Nov 2020 17:16:43 -0800 Subject: [PATCH 3/3] Add unit tests --- adafruit_requests.py | 8 +-- tests/concurrent_test.py | 86 ++++++++++++++++++++++++++++ tests/legacy_mocket.py | 7 ++- tests/legacy_test.py | 120 +++++++++++++++++++++++++++++++++++++++ tests/mocket.py | 4 ++ tests/reuse_test.py | 17 +++--- 6 files changed, 229 insertions(+), 13 deletions(-) create mode 100644 tests/concurrent_test.py diff --git a/adafruit_requests.py b/adafruit_requests.py index f747b5d..483bb80 100644 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -429,7 +429,7 @@ def _get_socket(self, host, port, proto, *, timeout=1): if any(self._socket_free.items()): self._free_sockets() else: - raise RuntimeError("Out of sockets") + raise RuntimeError("Sending request failed") retry_count += 1 try: @@ -450,7 +450,7 @@ def _get_socket(self, host, port, proto, *, timeout=1): except MemoryError: sock.close() sock = None - except OSError as e: + except OSError: sock.close() sock = None @@ -619,8 +619,8 @@ def connect(self, address): """connect wrapper to add non-standard mode parameter""" try: return self._socket.connect(address, self._mode) - except RuntimeError as e: - raise OSError(errno.ENOMEM) + except RuntimeError as error: + raise OSError(errno.ENOMEM) from error class _FakeSSLContext: diff --git a/tests/concurrent_test.py b/tests/concurrent_test.py new file mode 100644 index 0000000..b857418 --- /dev/null +++ b/tests/concurrent_test.py @@ -0,0 +1,86 @@ +from unittest import mock +import mocket +import pytest +import errno +import adafruit_requests + +ip = "1.2.3.4" +host = "wifitest.adafruit.com" +host2 = "wifitest2.adafruit.com" +path = "/testwifi/index.html" +text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)" +response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text + + +def test_second_connect_fails_memoryerror(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response) + sock2 = mocket.Mocket(response) + sock3 = mocket.Mocket(response) + pool.socket.call_count = 0 # Reset call count + pool.socket.side_effect = [sock, sock2, sock3] + sock2.connect.side_effect = MemoryError() + + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [mock.call(b"testwifi/index.html"),] + ) + + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),] + ) + assert r.text == str(text, "utf-8") + + host2 = "test.adafruit.com" + s.get("https://" + host2 + path) + + sock.connect.assert_called_once_with((host, 443)) + sock2.connect.assert_called_once_with((host2, 443)) + sock3.connect.assert_called_once_with((host2, 443)) + # Make sure that the socket is closed after send fails. + sock.close.assert_called_once() + sock2.close.assert_called_once() + assert sock3.close.call_count == 0 + assert pool.socket.call_count == 3 + + +def test_second_connect_fails_oserror(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response) + sock2 = mocket.Mocket(response) + sock3 = mocket.Mocket(response) + pool.socket.call_count = 0 # Reset call count + pool.socket.side_effect = [sock, sock2, sock3] + sock2.connect.side_effect = OSError(errno.ENOMEM) + + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [mock.call(b"testwifi/index.html"),] + ) + + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),] + ) + assert r.text == str(text, "utf-8") + + host2 = "test.adafruit.com" + s.get("https://" + host2 + path) + + sock.connect.assert_called_once_with((host, 443)) + sock2.connect.assert_called_once_with((host2, 443)) + sock3.connect.assert_called_once_with((host2, 443)) + # Make sure that the socket is closed after send fails. + sock.close.assert_called_once() + sock2.close.assert_called_once() + assert sock3.close.call_count == 0 + assert pool.socket.call_count == 3 diff --git a/tests/legacy_mocket.py b/tests/legacy_mocket.py index d256efc..40d1045 100644 --- a/tests/legacy_mocket.py +++ b/tests/legacy_mocket.py @@ -16,11 +16,15 @@ def __init__(self, response): self.send = mock.Mock(side_effect=self._send) self.readline = mock.Mock(side_effect=self._readline) self.recv = mock.Mock(side_effect=self._recv) + self.fail_next_send = False self._response = response self._position = 0 def _send(self, data): - return len(data) + if self.fail_next_send: + self.fail_next_send = False + raise RuntimeError("Send failed") + return None def _readline(self): i = self._response.find(b"\r\n", self._position) @@ -32,4 +36,5 @@ def _recv(self, count): end = self._position + count r = self._response[self._position : end] self._position = end + print(r) return r diff --git a/tests/legacy_test.py b/tests/legacy_test.py index 3d9cdbb..37e4ce2 100644 --- a/tests/legacy_test.py +++ b/tests/legacy_test.py @@ -1,6 +1,7 @@ from unittest import mock import legacy_mocket as mocket import json +import pytest import adafruit_requests ip = "1.2.3.4" @@ -49,3 +50,122 @@ def test_post_string(): sock.connect.assert_called_once_with((ip, 80)) sock.send.assert_called_with(b"31F") r.close() + + +def test_second_tls_send_fails(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + sock2 = mocket.Mocket(headers + encoded) + mocket.socket.call_count = 0 # Reset call count + mocket.socket.side_effect = [sock, sock2] + + adafruit_requests.set_socket(mocket, mocket.interface) + r = adafruit_requests.get("https://" + host + "/testwifi/index.html") + + sock.send.assert_has_calls( + [mock.call(b"testwifi/index.html"),] + ) + + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),] + ) + assert r.text == str(encoded, "utf-8") + + sock.fail_next_send = True + adafruit_requests.get("https://" + host + "/get2") + + sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE) + sock2.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE) + # Make sure that the socket is closed after send fails. + sock.close.assert_called_once() + assert sock2.close.call_count == 0 + assert mocket.socket.call_count == 2 + + +def test_second_send_fails(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + sock2 = mocket.Mocket(headers + encoded) + mocket.socket.call_count = 0 # Reset call count + mocket.socket.side_effect = [sock, sock2] + + adafruit_requests.set_socket(mocket, mocket.interface) + r = adafruit_requests.get("http://" + host + "/testwifi/index.html") + + sock.send.assert_has_calls( + [mock.call(b"testwifi/index.html"),] + ) + + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),] + ) + assert r.text == str(encoded, "utf-8") + + sock.fail_next_send = True + adafruit_requests.get("http://" + host + "/get2") + + sock.connect.assert_called_once_with((ip, 80)) + sock2.connect.assert_called_once_with((ip, 80)) + # Make sure that the socket is closed after send fails. + sock.close.assert_called_once() + assert sock2.close.call_count == 0 + assert mocket.socket.call_count == 2 + + +def test_first_read_fails(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(b"") + mocket.socket.call_count = 0 # Reset call count + mocket.socket.side_effect = [sock] + + adafruit_requests.set_socket(mocket, mocket.interface) + + with pytest.raises(RuntimeError): + r = adafruit_requests.get("http://" + host + "/testwifi/index.html") + + sock.send.assert_has_calls( + [mock.call(b"testwifi/index.html"),] + ) + + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),] + ) + + sock.connect.assert_called_once_with((ip, 80)) + # Make sure that the socket is closed after the first receive fails. + sock.close.assert_called_once() + assert mocket.socket.call_count == 1 + + +def test_second_tls_connect_fails(): + mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(headers + encoded) + sock2 = mocket.Mocket(headers + encoded) + sock3 = mocket.Mocket(headers + encoded) + mocket.socket.call_count = 0 # Reset call count + mocket.socket.side_effect = [sock, sock2, sock3] + sock2.connect.side_effect = RuntimeError("error connecting") + + adafruit_requests.set_socket(mocket, mocket.interface) + r = adafruit_requests.get("https://" + host + "/testwifi/index.html") + + sock.send.assert_has_calls( + [mock.call(b"testwifi/index.html"),] + ) + + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),] + ) + assert r.text == str(encoded, "utf-8") + + host2 = "test.adafruit.com" + r = adafruit_requests.get("https://" + host2 + "/get2") + + sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE) + sock2.connect.assert_called_once_with((host2, 443), mocket.interface.TLS_MODE) + sock3.connect.assert_called_once_with((host2, 443), mocket.interface.TLS_MODE) + # Make sure that the socket is closed after send fails. + sock.close.assert_called_once() + sock2.close.assert_called_once() + assert sock3.close.call_count == 0 + assert mocket.socket.call_count == 3 diff --git a/tests/mocket.py b/tests/mocket.py index 916bbef..bc24daf 100644 --- a/tests/mocket.py +++ b/tests/mocket.py @@ -20,8 +20,12 @@ def __init__(self, response): self.recv_into = mock.Mock(side_effect=self._recv_into) self._response = response self._position = 0 + self.fail_next_send = False def _send(self, data): + if self.fail_next_send: + self.fail_next_send = False + return 0 return len(data) def _readline(self): diff --git a/tests/reuse_test.py b/tests/reuse_test.py index 6e33e7e..21b224c 100644 --- a/tests/reuse_test.py +++ b/tests/reuse_test.py @@ -143,8 +143,10 @@ def test_connect_out_of_memory(): def test_second_send_fails(): pool = mocket.MocketPool() pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) - sock = mocket.Mocket(response + response) - pool.socket.return_value = sock + sock = mocket.Mocket(response) + sock2 = mocket.Mocket(response) + pool.socket.side_effect = [sock, sock2] + ssl = mocket.SSLContext() s = adafruit_requests.Session(pool, ssl) @@ -159,13 +161,12 @@ def test_second_send_fails(): ) assert r.text == str(text, "utf-8") - sock.send.side_effect = None - sock.send.return_value = 0 - - with pytest.raises(RuntimeError): - s.get("https://" + host + path + "2") + sock.fail_next_send = True + s.get("https://" + host + path + "2") sock.connect.assert_called_once_with((host, 443)) + sock2.connect.assert_called_once_with((host, 443)) # Make sure that the socket is closed after send fails. sock.close.assert_called_once() - pool.socket.assert_called_once() + assert sock2.close.call_count == 0 + assert pool.socket.call_count == 2