diff --git a/adafruit_requests.py b/adafruit_requests.py index 8d157f1..2d4d578 100755 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -377,26 +377,28 @@ def __init__(self, socket_pool, ssl_context=None): self._last_response = None def _free_socket(self, socket): - if socket not in self._open_sockets.values(): raise RuntimeError("Socket not from session") self._socket_free[socket] = True + def _close_socket(self, sock): + sock.close() + del self._socket_free[sock] + key = None + for k in self._open_sockets: + if self._open_sockets[k] == sock: + key = k + break + if key: + del self._open_sockets[key] + def _free_sockets(self): free_sockets = [] for sock in self._socket_free: if self._socket_free[sock]: - sock.close() free_sockets.append(sock) for sock in free_sockets: - del self._socket_free[sock] - key = None - for k in self._open_sockets: - if self._open_sockets[k] == sock: - key = k - break - if key: - del self._open_sockets[key] + self._close_socket(sock) def _get_socket(self, host, port, proto, *, timeout=1): key = (host, port, proto) @@ -440,6 +442,61 @@ def _get_socket(self, host, port, proto, *, timeout=1): self._socket_free[sock] = False return sock + @staticmethod + def _send(socket, data): + total_sent = 0 + while total_sent < len(data): + sent = socket.send(data[total_sent:]) + if sent is None: + sent = len(data) + if sent == 0: + raise RuntimeError("Connection closed") + total_sent += sent + + def _send_request(self, socket, host, method, path, headers, data, json): + # pylint: disable=too-many-arguments + self._send(socket, bytes(method, "utf-8")) + self._send(socket, b" /") + self._send(socket, bytes(path, "utf-8")) + self._send(socket, b" HTTP/1.1\r\n") + if "Host" not in headers: + self._send(socket, b"Host: ") + self._send(socket, bytes(host, "utf-8")) + self._send(socket, b"\r\n") + if "User-Agent" not in headers: + self._send(socket, b"User-Agent: Adafruit CircuitPython\r\n") + # Iterate over keys to avoid tuple alloc + for k in headers: + self._send(socket, k.encode()) + self._send(socket, b": ") + self._send(socket, headers[k].encode()) + self._send(socket, b"\r\n") + if json is not None: + assert data is None + # pylint: disable=import-outside-toplevel + try: + import json as json_module + except ImportError: + import ujson as json_module + data = json_module.dumps(json) + self._send(socket, b"Content-Type: application/json\r\n") + if data: + if isinstance(data, dict): + self._send( + socket, b"Content-Type: application/x-www-form-urlencoded\r\n" + ) + _post_data = "" + for k in data: + _post_data = "{}&{}={}".format(_post_data, k, data[k]) + data = _post_data[1:] + self._send(socket, b"Content-Length: %d\r\n" % len(data)) + self._send(socket, b"\r\n") + if data: + if isinstance(data, bytearray): + self._send(socket, bytes(data)) + else: + self._send(socket, bytes(data, "utf-8")) + # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals def request( self, method, url, data=None, json=None, headers=None, stream=False, timeout=60 @@ -476,42 +533,11 @@ def request( self._last_response = None socket = self._get_socket(host, port, proto, timeout=timeout) - socket.send( - b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8")) - ) - if "Host" not in headers: - socket.send(b"Host: %s\r\n" % bytes(host, "utf-8")) - if "User-Agent" not in headers: - socket.send(b"User-Agent: Adafruit CircuitPython\r\n") - # Iterate over keys to avoid tuple alloc - for k in headers: - socket.send(k.encode()) - socket.send(b": ") - socket.send(headers[k].encode()) - socket.send(b"\r\n") - if json is not None: - assert data is None - # pylint: disable=import-outside-toplevel - try: - import json as json_module - except ImportError: - import ujson as json_module - data = json_module.dumps(json) - socket.send(b"Content-Type: application/json\r\n") - if data: - if isinstance(data, dict): - socket.send(b"Content-Type: application/x-www-form-urlencoded\r\n") - _post_data = "" - for k in data: - _post_data = "{}&{}={}".format(_post_data, k, data[k]) - data = _post_data[1:] - socket.send(b"Content-Length: %d\r\n" % len(data)) - socket.send(b"\r\n") - if data: - if isinstance(data, bytearray): - socket.send(bytes(data)) - else: - socket.send(bytes(data, "utf-8")) + try: + self._send_request(socket, host, method, path, headers, data, json) + except: + self._close_socket(socket) + raise resp = Response(socket, self) # our response if "location" in resp.headers and 300 <= resp.status_code <= 399: @@ -557,6 +583,7 @@ def __init__(self, socket, tls_mode): self.settimeout = socket.settimeout self.send = socket.send self.recv = socket.recv + self.close = socket.close def connect(self, address): """connect wrapper to add non-standard mode parameter""" diff --git a/tests/chunk_test.py b/tests/chunk_test.py index 67f09ec..585cebd 100644 --- a/tests/chunk_test.py +++ b/tests/chunk_test.py @@ -39,10 +39,16 @@ def test_get_text(): r = s.get("http://" + host + path) sock.connect.assert_called_once_with((ip, 80)) + sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), - mock.call(b"Host: wifitest.adafruit.com\r\n"), + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html"), + mock.call(b" HTTP/1.1\r\n"), ] ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) assert r.text == str(text, "utf-8") diff --git a/tests/header_test.py b/tests/header_test.py index 2c9db60..46dcc76 100644 --- a/tests/header_test.py +++ b/tests/header_test.py @@ -14,7 +14,12 @@ def test_json(): sock = mocket.Mocket(response_headers) pool.socket.return_value = sock sent = [] - sock.send.side_effect = sent.append + + def _send(data): + sent.append(data) + return len(data) + + sock.send.side_effect = _send s = adafruit_requests.Session(pool) headers = {"user-agent": "blinka/1.0.0"} diff --git a/tests/legacy_mocket.py b/tests/legacy_mocket.py index 4a37bd2..d256efc 100644 --- a/tests/legacy_mocket.py +++ b/tests/legacy_mocket.py @@ -13,12 +13,15 @@ def __init__(self, response): self.settimeout = mock.Mock() self.close = mock.Mock() self.connect = mock.Mock() - self.send = mock.Mock() + self.send = mock.Mock(side_effect=self._send) self.readline = mock.Mock(side_effect=self._readline) self.recv = mock.Mock(side_effect=self._recv) self._response = response self._position = 0 + def _send(self, data): + return len(data) + def _readline(self): i = self._response.find(b"\r\n", self._position) r = self._response[self._position : i + 2] diff --git a/tests/mocket.py b/tests/mocket.py index ec9a557..916bbef 100644 --- a/tests/mocket.py +++ b/tests/mocket.py @@ -14,13 +14,16 @@ def __init__(self, response): self.settimeout = mock.Mock() self.close = mock.Mock() self.connect = mock.Mock() - self.send = mock.Mock() + self.send = mock.Mock(side_effect=self._send) self.readline = mock.Mock(side_effect=self._readline) self.recv = mock.Mock(side_effect=self._recv) self.recv_into = mock.Mock(side_effect=self._recv_into) self._response = response self._position = 0 + def _send(self, data): + return len(data) + def _readline(self): i = self._response.find(b"\r\n", self._position) r = self._response[self._position : i + 2] diff --git a/tests/post_test.py b/tests/post_test.py index c8660a2..c2a9b7e 100644 --- a/tests/post_test.py +++ b/tests/post_test.py @@ -21,8 +21,17 @@ def test_method(): s = adafruit_requests.Session(pool) r = s.post("http://" + host + "/post") sock.connect.assert_called_once_with((ip, 80)) + + sock.send.assert_has_calls( + [ + mock.call(b"POST"), + mock.call(b" /"), + mock.call(b"post"), + mock.call(b" HTTP/1.1\r\n"), + ] + ) sock.send.assert_has_calls( - [mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")] + [mock.call(b"Host: "), mock.call(b"httpbin.org"),] ) diff --git a/tests/protocol_test.py b/tests/protocol_test.py index c7ad9da..c00c485 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -32,12 +32,18 @@ def test_get_https_text(): r = s.get("https://" + host + path) sock.connect.assert_called_once_with((host, 443)) + sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), - mock.call(b"Host: wifitest.adafruit.com\r\n"), + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html"), + mock.call(b" HTTP/1.1\r\n"), ] ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) assert r.text == str(text, "utf-8") # Close isn't needed but can be called to release the socket early. @@ -54,10 +60,16 @@ def test_get_http_text(): r = s.get("http://" + host + path) sock.connect.assert_called_once_with((ip, 80)) + sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), - mock.call(b"Host: wifitest.adafruit.com\r\n"), + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html"), + mock.call(b" HTTP/1.1\r\n"), ] ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) assert r.text == str(text, "utf-8") diff --git a/tests/reuse_test.py b/tests/reuse_test.py index 2e06931..6e33e7e 100644 --- a/tests/reuse_test.py +++ b/tests/reuse_test.py @@ -10,35 +10,47 @@ text = b"This is a test of Adafruit WiFi!\r\nIf you can read this, its working :)" response = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + text -# def test_get_twice(): -# pool = mocket.MocketPool() -# pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) -# sock = mocket.Mocket(response + response) -# pool.socket.return_value = sock -# ssl = mocket.SSLContext() - -# s = adafruit_requests.Session(pool, ssl) -# r = s.get("https://" + host + path) - -# sock.send.assert_has_calls( -# [ -# mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), -# mock.call(b"Host: wifitest.adafruit.com\r\n"), -# ] -# ) -# assert r.text == str(text, "utf-8") - -# r = s.get("https://" + host + path + "2") -# sock.send.assert_has_calls( -# [ -# mock.call(b"GET /testwifi/index.html2 HTTP/1.1\r\n"), -# mock.call(b"Host: wifitest.adafruit.com\r\n"), -# ] -# ) - -# assert r.text == str(text, "utf-8") -# sock.connect.assert_called_once_with((host, 443)) -# pool.socket.assert_called_once() + +def test_get_twice(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response + response) + pool.socket.return_value = sock + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [ + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html"), + mock.call(b" HTTP/1.1\r\n"), + ] + ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) + assert r.text == str(text, "utf-8") + + r = s.get("https://" + host + path + "2") + + sock.send.assert_has_calls( + [ + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html2"), + mock.call(b" HTTP/1.1\r\n"), + ] + ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) + + assert r.text == str(text, "utf-8") + sock.connect.assert_called_once_with((host, 443)) + pool.socket.assert_called_once() def test_get_twice_after_second(): @@ -53,18 +65,29 @@ def test_get_twice_after_second(): sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), - mock.call(b"Host: wifitest.adafruit.com\r\n"), + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html"), + mock.call(b" HTTP/1.1\r\n"), ] ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) r2 = s.get("https://" + host + path + "2") + sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html2 HTTP/1.1\r\n"), - mock.call(b"Host: wifitest.adafruit.com\r\n"), + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html2"), + mock.call(b" HTTP/1.1\r\n"), ] ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) sock.connect.assert_called_once_with((host, 443)) pool.socket.assert_called_once() @@ -87,21 +110,62 @@ def test_connect_out_of_memory(): sock.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), - mock.call(b"Host: wifitest.adafruit.com\r\n"), + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html"), + mock.call(b" HTTP/1.1\r\n"), ] ) + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),] + ) assert r.text == str(text, "utf-8") r = s.get("https://" + host2 + path) sock3.send.assert_has_calls( [ - mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"), - mock.call(b"Host: wifitest2.adafruit.com\r\n"), + mock.call(b"GET"), + mock.call(b" /"), + mock.call(b"testwifi/index.html"), + mock.call(b" HTTP/1.1\r\n"), ] ) + sock3.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest2.adafruit.com"),] + ) assert r.text == str(text, "utf-8") sock.close.assert_called_once() sock.connect.assert_called_once_with((host, 443)) sock3.connect.assert_called_once_with((host2, 443)) + + +def test_second_send_fails(): + pool = mocket.MocketPool() + pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),) + sock = mocket.Mocket(response + response) + pool.socket.return_value = sock + ssl = mocket.SSLContext() + + s = adafruit_requests.Session(pool, ssl) + r = s.get("https://" + host + path) + + sock.send.assert_has_calls( + [mock.call(b"testwifi/index.html"),] + ) + + sock.send.assert_has_calls( + [mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),] + ) + assert r.text == str(text, "utf-8") + + sock.send.side_effect = None + sock.send.return_value = 0 + + with pytest.raises(RuntimeError): + s.get("https://" + host + path + "2") + + sock.connect.assert_called_once_with((host, 443)) + # Make sure that the socket is closed after send fails. + sock.close.assert_called_once() + pool.socket.assert_called_once()