Skip to content

Commit

Permalink
Merge pull request #36 from tannewt/check_send
Browse files Browse the repository at this point in the history
Check for send to return 0
  • Loading branch information
tannewt authored Sep 24, 2020
2 parents 9aaf781 + c044eab commit 46007a6
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 93 deletions.
119 changes: 73 additions & 46 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,26 +377,28 @@ def __init__(self, socket_pool, ssl_context=None):
self._last_response = None

def _free_socket(self, socket):

if socket not in self._open_sockets.values():
raise RuntimeError("Socket not from session")
self._socket_free[socket] = True

def _close_socket(self, sock):
sock.close()
del self._socket_free[sock]
key = None
for k in self._open_sockets:
if self._open_sockets[k] == sock:
key = k
break
if key:
del self._open_sockets[key]

def _free_sockets(self):
free_sockets = []
for sock in self._socket_free:
if self._socket_free[sock]:
sock.close()
free_sockets.append(sock)
for sock in free_sockets:
del self._socket_free[sock]
key = None
for k in self._open_sockets:
if self._open_sockets[k] == sock:
key = k
break
if key:
del self._open_sockets[key]
self._close_socket(sock)

def _get_socket(self, host, port, proto, *, timeout=1):
key = (host, port, proto)
Expand Down Expand Up @@ -440,6 +442,61 @@ def _get_socket(self, host, port, proto, *, timeout=1):
self._socket_free[sock] = False
return sock

@staticmethod
def _send(socket, data):
total_sent = 0
while total_sent < len(data):
sent = socket.send(data[total_sent:])
if sent is None:
sent = len(data)
if sent == 0:
raise RuntimeError("Connection closed")
total_sent += sent

def _send_request(self, socket, host, method, path, headers, data, json):
# pylint: disable=too-many-arguments
self._send(socket, bytes(method, "utf-8"))
self._send(socket, b" /")
self._send(socket, bytes(path, "utf-8"))
self._send(socket, b" HTTP/1.1\r\n")
if "Host" not in headers:
self._send(socket, b"Host: ")
self._send(socket, bytes(host, "utf-8"))
self._send(socket, b"\r\n")
if "User-Agent" not in headers:
self._send(socket, b"User-Agent: Adafruit CircuitPython\r\n")
# Iterate over keys to avoid tuple alloc
for k in headers:
self._send(socket, k.encode())
self._send(socket, b": ")
self._send(socket, headers[k].encode())
self._send(socket, b"\r\n")
if json is not None:
assert data is None
# pylint: disable=import-outside-toplevel
try:
import json as json_module
except ImportError:
import ujson as json_module
data = json_module.dumps(json)
self._send(socket, b"Content-Type: application/json\r\n")
if data:
if isinstance(data, dict):
self._send(
socket, b"Content-Type: application/x-www-form-urlencoded\r\n"
)
_post_data = ""
for k in data:
_post_data = "{}&{}={}".format(_post_data, k, data[k])
data = _post_data[1:]
self._send(socket, b"Content-Length: %d\r\n" % len(data))
self._send(socket, b"\r\n")
if data:
if isinstance(data, bytearray):
self._send(socket, bytes(data))
else:
self._send(socket, bytes(data, "utf-8"))

# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
def request(
self, method, url, data=None, json=None, headers=None, stream=False, timeout=60
Expand Down Expand Up @@ -476,42 +533,11 @@ def request(
self._last_response = None

socket = self._get_socket(host, port, proto, timeout=timeout)
socket.send(
b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))
)
if "Host" not in headers:
socket.send(b"Host: %s\r\n" % bytes(host, "utf-8"))
if "User-Agent" not in headers:
socket.send(b"User-Agent: Adafruit CircuitPython\r\n")
# Iterate over keys to avoid tuple alloc
for k in headers:
socket.send(k.encode())
socket.send(b": ")
socket.send(headers[k].encode())
socket.send(b"\r\n")
if json is not None:
assert data is None
# pylint: disable=import-outside-toplevel
try:
import json as json_module
except ImportError:
import ujson as json_module
data = json_module.dumps(json)
socket.send(b"Content-Type: application/json\r\n")
if data:
if isinstance(data, dict):
socket.send(b"Content-Type: application/x-www-form-urlencoded\r\n")
_post_data = ""
for k in data:
_post_data = "{}&{}={}".format(_post_data, k, data[k])
data = _post_data[1:]
socket.send(b"Content-Length: %d\r\n" % len(data))
socket.send(b"\r\n")
if data:
if isinstance(data, bytearray):
socket.send(bytes(data))
else:
socket.send(bytes(data, "utf-8"))
try:
self._send_request(socket, host, method, path, headers, data, json)
except:
self._close_socket(socket)
raise

resp = Response(socket, self) # our response
if "location" in resp.headers and 300 <= resp.status_code <= 399:
Expand Down Expand Up @@ -557,6 +583,7 @@ def __init__(self, socket, tls_mode):
self.settimeout = socket.settimeout
self.send = socket.send
self.recv = socket.recv
self.close = socket.close

def connect(self, address):
"""connect wrapper to add non-standard mode parameter"""
Expand Down
10 changes: 8 additions & 2 deletions tests/chunk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,16 @@ def test_get_text():
r = s.get("http://" + host + path)

sock.connect.assert_called_once_with((ip, 80))

sock.send.assert_has_calls(
[
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
mock.call(b"Host: wifitest.adafruit.com\r\n"),
mock.call(b"GET"),
mock.call(b" /"),
mock.call(b"testwifi/index.html"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
)
assert r.text == str(text, "utf-8")
7 changes: 6 additions & 1 deletion tests/header_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ def test_json():
sock = mocket.Mocket(response_headers)
pool.socket.return_value = sock
sent = []
sock.send.side_effect = sent.append

def _send(data):
sent.append(data)
return len(data)

sock.send.side_effect = _send

s = adafruit_requests.Session(pool)
headers = {"user-agent": "blinka/1.0.0"}
Expand Down
5 changes: 4 additions & 1 deletion tests/legacy_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ def __init__(self, response):
self.settimeout = mock.Mock()
self.close = mock.Mock()
self.connect = mock.Mock()
self.send = mock.Mock()
self.send = mock.Mock(side_effect=self._send)
self.readline = mock.Mock(side_effect=self._readline)
self.recv = mock.Mock(side_effect=self._recv)
self._response = response
self._position = 0

def _send(self, data):
return len(data)

def _readline(self):
i = self._response.find(b"\r\n", self._position)
r = self._response[self._position : i + 2]
Expand Down
5 changes: 4 additions & 1 deletion tests/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ def __init__(self, response):
self.settimeout = mock.Mock()
self.close = mock.Mock()
self.connect = mock.Mock()
self.send = mock.Mock()
self.send = mock.Mock(side_effect=self._send)
self.readline = mock.Mock(side_effect=self._readline)
self.recv = mock.Mock(side_effect=self._recv)
self.recv_into = mock.Mock(side_effect=self._recv_into)
self._response = response
self._position = 0

def _send(self, data):
return len(data)

def _readline(self):
i = self._response.find(b"\r\n", self._position)
r = self._response[self._position : i + 2]
Expand Down
11 changes: 10 additions & 1 deletion tests/post_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@ def test_method():
s = adafruit_requests.Session(pool)
r = s.post("http://" + host + "/post")
sock.connect.assert_called_once_with((ip, 80))

sock.send.assert_has_calls(
[
mock.call(b"POST"),
mock.call(b" /"),
mock.call(b"post"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")]
[mock.call(b"Host: "), mock.call(b"httpbin.org"),]
)


Expand Down
20 changes: 16 additions & 4 deletions tests/protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ def test_get_https_text():
r = s.get("https://" + host + path)

sock.connect.assert_called_once_with((host, 443))

sock.send.assert_has_calls(
[
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
mock.call(b"Host: wifitest.adafruit.com\r\n"),
mock.call(b"GET"),
mock.call(b" /"),
mock.call(b"testwifi/index.html"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
)
assert r.text == str(text, "utf-8")

# Close isn't needed but can be called to release the socket early.
Expand All @@ -54,10 +60,16 @@ def test_get_http_text():
r = s.get("http://" + host + path)

sock.connect.assert_called_once_with((ip, 80))

sock.send.assert_has_calls(
[
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
mock.call(b"Host: wifitest.adafruit.com\r\n"),
mock.call(b"GET"),
mock.call(b" /"),
mock.call(b"testwifi/index.html"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
)
assert r.text == str(text, "utf-8")
Loading

0 comments on commit 46007a6

Please sign in to comment.