Skip to content

Commit

Permalink
Merge pull request #50 from tannewt/better_error_handling
Browse files Browse the repository at this point in the history
Better handle errors by retrying
  • Loading branch information
makermelissa authored Nov 6, 2020
2 parents 8b24815 + 7f2877e commit 333b586
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 44 deletions.
99 changes: 64 additions & 35 deletions adafruit_requests.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
__version__ = "0.0.0-auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git"

import errno


class _RawResponse:
def __init__(self, response):
Expand All @@ -73,6 +75,10 @@ def readinto(self, buf):
return self._response._readinto(buf) # pylint: disable=protected-access


class _SendFailed(Exception):
"""Custom exception to abort sending a request."""


class Response:
"""The response from a request, contains all the headers/content"""

Expand All @@ -94,11 +100,13 @@ def __init__(self, sock, session=None):
self._chunked = False

self._backwards_compatible = not hasattr(sock, "recv_into")
if self._backwards_compatible:
print("Socket missing recv_into. Using more memory to be compatible")

http = self._readto(b" ")
if not http:
if session:
session._close_socket(self.socket)
else:
self.socket.close()
raise RuntimeError("Unable to read HTTP response.")
self.status_code = int(bytes(self._readto(b" ")))
self.reason = self._readto(b"\r\n")
Expand Down Expand Up @@ -414,30 +422,41 @@ def _get_socket(self, host, port, proto, *, timeout=1):
addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]
sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2])
connect_host = addr_info[-1][0]
if proto == "https:":
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout) # socket read timeout
ok = True
try:
ok = sock.connect((connect_host, port))
except MemoryError:
if not any(self._socket_free.items()):
raise
ok = False

# We couldn't connect due to memory so clean up the open sockets.
if not ok:
self._free_sockets()
# Recreate the socket because the ESP-IDF won't retry the connection if it failed once.
sock = None # Clear first so the first socket can be cleaned up.
sock = self._socket_pool.socket(addr_info[0], addr_info[1], addr_info[2])
retry_count = 0
sock = None
while retry_count < 5 and sock is None:
if retry_count > 0:
if any(self._socket_free.items()):
self._free_sockets()
else:
raise RuntimeError("Sending request failed")
retry_count += 1

try:
sock = self._socket_pool.socket(
addr_info[0], addr_info[1], addr_info[2]
)
except OSError:
continue

connect_host = addr_info[-1][0]
if proto == "https:":
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout) # socket read timeout
sock.connect((connect_host, port))

try:
sock.connect((connect_host, port))
except MemoryError:
sock.close()
sock = None
except OSError:
sock.close()
sock = None

if sock is None:
raise RuntimeError("Repeated socket failures")

self._open_sockets[key] = sock
self._socket_free[sock] = False
return sock
Expand All @@ -446,11 +465,15 @@ def _get_socket(self, host, port, proto, *, timeout=1):
def _send(socket, data):
total_sent = 0
while total_sent < len(data):
sent = socket.send(data[total_sent:])
# ESP32SPI sockets raise a RuntimeError when unable to send.
try:
sent = socket.send(data[total_sent:])
except RuntimeError:
sent = 0
if sent is None:
sent = len(data)
if sent == 0:
raise RuntimeError("Connection closed")
raise _SendFailed()
total_sent += sent

def _send_request(self, socket, host, method, path, headers, data, json):
Expand Down Expand Up @@ -532,12 +555,19 @@ def request(
self._last_response.close()
self._last_response = None

socket = self._get_socket(host, port, proto, timeout=timeout)
try:
self._send_request(socket, host, method, path, headers, data, json)
except:
self._close_socket(socket)
raise
# We may fail to send the request if the socket we got is closed already. So, try a second
# time in that case.
retry_count = 0
while retry_count < 2:
retry_count += 1
socket = self._get_socket(host, port, proto, timeout=timeout)
try:
self._send_request(socket, host, method, path, headers, data, json)
break
except _SendFailed:
self._close_socket(socket)
if retry_count > 1:
raise

resp = Response(socket, self) # our response
if "location" in resp.headers and 300 <= resp.status_code <= 399:
Expand Down Expand Up @@ -588,10 +618,9 @@ def __init__(self, socket, tls_mode):
def connect(self, address):
"""connect wrapper to add non-standard mode parameter"""
try:
self._socket.connect(address, self._mode)
return True
except RuntimeError:
return False
return self._socket.connect(address, self._mode)
except RuntimeError as error:
raise OSError(errno.ENOMEM) from error


class _FakeSSLContext:
Expand Down
86 changes: 86 additions & 0 deletions tests/concurrent_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from unittest import mock
import mocket
import pytest
import errno
import adafruit_requests

ip = "1.2.3.4"
host = "wifitest.adafruit.com"
host2 = "wifitest2.adafruit.com"
path = "/testwifi/index.html"
text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)"
response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text


def test_second_connect_fails_memoryerror():
pool = mocket.MocketPool()
pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(response)
sock2 = mocket.Mocket(response)
sock3 = mocket.Mocket(response)
pool.socket.call_count = 0 # Reset call count
pool.socket.side_effect = [sock, sock2, sock3]
sock2.connect.side_effect = MemoryError()

ssl = mocket.SSLContext()

s = adafruit_requests.Session(pool, ssl)
r = s.get("https://" + host + path)

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),]
)
assert r.text == str(text, "utf-8")

host2 = "test.adafruit.com"
s.get("https://" + host2 + path)

sock.connect.assert_called_once_with((host, 443))
sock2.connect.assert_called_once_with((host2, 443))
sock3.connect.assert_called_once_with((host2, 443))
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
sock2.close.assert_called_once()
assert sock3.close.call_count == 0
assert pool.socket.call_count == 3


def test_second_connect_fails_oserror():
pool = mocket.MocketPool()
pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(response)
sock2 = mocket.Mocket(response)
sock3 = mocket.Mocket(response)
pool.socket.call_count = 0 # Reset call count
pool.socket.side_effect = [sock, sock2, sock3]
sock2.connect.side_effect = OSError(errno.ENOMEM)

ssl = mocket.SSLContext()

s = adafruit_requests.Session(pool, ssl)
r = s.get("https://" + host + path)

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),]
)
assert r.text == str(text, "utf-8")

host2 = "test.adafruit.com"
s.get("https://" + host2 + path)

sock.connect.assert_called_once_with((host, 443))
sock2.connect.assert_called_once_with((host2, 443))
sock3.connect.assert_called_once_with((host2, 443))
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
sock2.close.assert_called_once()
assert sock3.close.call_count == 0
assert pool.socket.call_count == 3
7 changes: 6 additions & 1 deletion tests/legacy_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ def __init__(self, response):
self.send = mock.Mock(side_effect=self._send)
self.readline = mock.Mock(side_effect=self._readline)
self.recv = mock.Mock(side_effect=self._recv)
self.fail_next_send = False
self._response = response
self._position = 0

def _send(self, data):
return len(data)
if self.fail_next_send:
self.fail_next_send = False
raise RuntimeError("Send failed")
return None

def _readline(self):
i = self._response.find(b"\r\n", self._position)
Expand All @@ -32,4 +36,5 @@ def _recv(self, count):
end = self._position + count
r = self._response[self._position : end]
self._position = end
print(r)
return r
120 changes: 120 additions & 0 deletions tests/legacy_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import mock
import legacy_mocket as mocket
import json
import pytest
import adafruit_requests

ip = "1.2.3.4"
Expand Down Expand Up @@ -49,3 +50,122 @@ def test_post_string():
sock.connect.assert_called_once_with((ip, 80))
sock.send.assert_called_with(b"31F")
r.close()


def test_second_tls_send_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(headers + encoded)
sock2 = mocket.Mocket(headers + encoded)
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock, sock2]

adafruit_requests.set_socket(mocket, mocket.interface)
r = adafruit_requests.get("https://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)
assert r.text == str(encoded, "utf-8")

sock.fail_next_send = True
adafruit_requests.get("https://" + host + "/get2")

sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE)
sock2.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE)
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
assert sock2.close.call_count == 0
assert mocket.socket.call_count == 2


def test_second_send_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(headers + encoded)
sock2 = mocket.Mocket(headers + encoded)
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock, sock2]

adafruit_requests.set_socket(mocket, mocket.interface)
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)
assert r.text == str(encoded, "utf-8")

sock.fail_next_send = True
adafruit_requests.get("http://" + host + "/get2")

sock.connect.assert_called_once_with((ip, 80))
sock2.connect.assert_called_once_with((ip, 80))
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
assert sock2.close.call_count == 0
assert mocket.socket.call_count == 2


def test_first_read_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(b"")
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock]

adafruit_requests.set_socket(mocket, mocket.interface)

with pytest.raises(RuntimeError):
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)

sock.connect.assert_called_once_with((ip, 80))
# Make sure that the socket is closed after the first receive fails.
sock.close.assert_called_once()
assert mocket.socket.call_count == 1


def test_second_tls_connect_fails():
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
sock = mocket.Mocket(headers + encoded)
sock2 = mocket.Mocket(headers + encoded)
sock3 = mocket.Mocket(headers + encoded)
mocket.socket.call_count = 0 # Reset call count
mocket.socket.side_effect = [sock, sock2, sock3]
sock2.connect.side_effect = RuntimeError("error connecting")

adafruit_requests.set_socket(mocket, mocket.interface)
r = adafruit_requests.get("https://" + host + "/testwifi/index.html")

sock.send.assert_has_calls(
[mock.call(b"testwifi/index.html"),]
)

sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
)
assert r.text == str(encoded, "utf-8")

host2 = "test.adafruit.com"
r = adafruit_requests.get("https://" + host2 + "/get2")

sock.connect.assert_called_once_with((host, 443), mocket.interface.TLS_MODE)
sock2.connect.assert_called_once_with((host2, 443), mocket.interface.TLS_MODE)
sock3.connect.assert_called_once_with((host2, 443), mocket.interface.TLS_MODE)
# Make sure that the socket is closed after send fails.
sock.close.assert_called_once()
sock2.close.assert_called_once()
assert sock3.close.call_count == 0
assert mocket.socket.call_count == 3
Loading

0 comments on commit 333b586

Please sign in to comment.