diff --git a/adafruit_requests.py b/adafruit_requests.py index aaa56cb..f59b2f4 100644 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -37,11 +37,40 @@ __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git" import errno +from types import TracebackType + +try: + from typing import Union, TypeVar, Optional, Dict, Any, List, Type + import types + import ssl + import adafruit_esp32spi.adafruit_esp32spi_socket as esp32_socket + import adafruit_wiznet5k.adafruit_wiznet5k_socket as wiznet_socket + import adafruit_fona.adafruit_fona_socket as cellular_socket + from adafruit_esp32spi.adafruit_esp32spi import ESP_SPIcontrol + from adafruit_wiznet5k.adafruit_wiznet5k import WIZNET5K + from adafruit_fona.adafruit_fona import FONA + import socket as cpython_socket + + SocketType = TypeVar( + "SocketType", + esp32_socket.socket, + wiznet_socket.socket, + cellular_socket.socket, + cpython_socket.socket, + ) + SocketpoolModuleType = types.ModuleType + SSLContextType = ( + ssl.SSLContext + ) # Can use either CircuitPython or CPython ssl module + InterfaceType = TypeVar("InterfaceType", ESP_SPIcontrol, WIZNET5K, FONA) + +except ImportError: + pass # CircuitPython 6.0 does not have the bytearray.split method. # This function emulates buf.split(needle)[0], which is the functionality # required. -def _buffer_split0(buf, needle): +def _buffer_split0(buf: Union[bytes, bytearray], needle: Union[bytes, bytearray]): index = buf.find(needle) if index == -1: return buf @@ -49,10 +78,10 @@ def _buffer_split0(buf, needle): class _RawResponse: - def __init__(self, response): + def __init__(self, response: "Response") -> None: self._response = response - def read(self, size=-1): + def read(self, size: int = -1) -> bytes: """Read as much as available or up to size and return it in a byte string. Do NOT use this unless you really need to. Reusing memory with `readinto` is much better. @@ -61,7 +90,7 @@ def read(self, size=-1): return self._response.content return self._response.socket.recv(size) - def readinto(self, buf): + def readinto(self, buf: bytearray) -> int: """Read as much as available into buf or until it is full. Returns the number of bytes read into buf.""" return self._response._readinto(buf) # pylint: disable=protected-access @@ -82,7 +111,7 @@ class Response: encoding = None - def __init__(self, sock, session=None): + def __init__(self, sock: SocketType, session: Optional["Session"] = None) -> None: self.socket = sock self.encoding = "utf-8" self._cached = None @@ -110,13 +139,18 @@ def __init__(self, sock, session=None): self._raw = None self._session = session - def __enter__(self): + def __enter__(self) -> "Response": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[type]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: self.close() - def _recv_into(self, buf, size=0): + def _recv_into(self, buf: bytearray, size: int = 0) -> int: if self._backwards_compatible: size = len(buf) if size == 0 else size b = self.socket.recv(size) @@ -126,7 +160,7 @@ def _recv_into(self, buf, size=0): return self.socket.recv_into(buf, size) @staticmethod - def _find(buf, needle, start, end): + def _find(buf: bytes, needle: bytes, start: int, end: int) -> int: if hasattr(buf, "find"): return buf.find(needle, start, end) result = -1 @@ -142,7 +176,7 @@ def _find(buf, needle, start, end): return result - def _readto(self, first, second=b""): + def _readto(self, first: bytes, second: bytes = b"") -> bytes: buf = self._receive_buffer end = self._received_length while True: @@ -187,7 +221,9 @@ def _readto(self, first, second=b""): return b"" - def _read_from_buffer(self, buf=None, nbytes=None): + def _read_from_buffer( + self, buf: Optional[bytearray] = None, nbytes: Optional[int] = None + ) -> int: if self._received_length == 0: return 0 read = self._received_length @@ -204,7 +240,7 @@ def _read_from_buffer(self, buf=None, nbytes=None): self._received_length = 0 return read - def _readinto(self, buf): + def _readinto(self, buf: bytearray) -> int: if not self.socket: raise RuntimeError( "Newer Response closed this one. Use Responses immediately." @@ -237,7 +273,7 @@ def _readinto(self, buf): return read - def _throw_away(self, nbytes): + def _throw_away(self, nbytes: int) -> None: nbytes -= self._read_from_buffer(nbytes=nbytes) buf = self._receive_buffer @@ -247,7 +283,7 @@ def _throw_away(self, nbytes): if remaining: self._recv_into(buf, remaining) - def close(self): + def close(self) -> None: """Drain the remaining ESP socket buffers. We assume we already got what we wanted.""" if not self.socket: return @@ -269,7 +305,7 @@ def close(self): self.socket.close() self.socket = None - def _parse_headers(self): + def _parse_headers(self) -> None: """ Parses the header portion of an HTTP request/response from the socket. Expects first line of HTTP request/response to have been read already. @@ -291,7 +327,7 @@ def _parse_headers(self): self._headers[title] = content @property - def headers(self): + def headers(self) -> Dict[str, str]: """ The response headers. Does not include headers from the trailer until the content has been read. @@ -299,7 +335,7 @@ def headers(self): return self._headers @property - def content(self): + def content(self) -> bytes: """The HTTP content direct from the socket, as bytes""" if self._cached is not None: if isinstance(self._cached, bytes): @@ -310,7 +346,7 @@ def content(self): return self._cached @property - def text(self): + def text(self) -> str: """The HTTP content, encoded into a string according to the HTTP header encoding""" if self._cached is not None: @@ -320,7 +356,7 @@ def text(self): self._cached = str(self.content, self.encoding) return self._cached - def json(self): + def json(self) -> Any: """The HTTP content, parsed into a json dictionary""" # pylint: disable=import-outside-toplevel import json @@ -344,7 +380,7 @@ def json(self): self.close() return obj - def iter_content(self, chunk_size=1, decode_unicode=False): + def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> bytes: """An iterator that will stream data by only reading 'chunk_size' bytes and yielding them, when we can't buffer the whole datastream""" if decode_unicode: @@ -366,7 +402,11 @@ def iter_content(self, chunk_size=1, decode_unicode=False): class Session: """HTTP session that shares sockets and ssl context.""" - def __init__(self, socket_pool, ssl_context=None): + def __init__( + self, + socket_pool: SocketpoolModuleType, + ssl_context: Optional[SSLContextType] = None, + ) -> None: self._socket_pool = socket_pool self._ssl_context = ssl_context # Hang onto open sockets so that we can reuse them. @@ -374,12 +414,12 @@ def __init__(self, socket_pool, ssl_context=None): self._socket_free = {} self._last_response = None - def _free_socket(self, socket): + def _free_socket(self, socket: SocketType) -> None: if socket not in self._open_sockets.values(): raise RuntimeError("Socket not from session") self._socket_free[socket] = True - def _close_socket(self, sock): + def _close_socket(self, sock: SocketType) -> None: sock.close() del self._socket_free[sock] key = None @@ -390,7 +430,7 @@ def _close_socket(self, sock): if key: del self._open_sockets[key] - def _free_sockets(self): + def _free_sockets(self) -> None: free_sockets = [] for sock, val in self._socket_free.items(): if val: @@ -398,7 +438,9 @@ def _free_sockets(self): for sock in free_sockets: self._close_socket(sock) - def _get_socket(self, host, port, proto, *, timeout=1): + def _get_socket( + self, host: str, port: int, proto: str, *, timeout: float = 1 + ) -> SocketType: # pylint: disable=too-many-branches key = (host, port, proto) if key in self._open_sockets: @@ -453,7 +495,7 @@ def _get_socket(self, host, port, proto, *, timeout=1): return sock @staticmethod - def _send(socket, data): + def _send(socket: SocketType, data: bytes): total_sent = 0 while total_sent < len(data): # ESP32SPI sockets raise a RuntimeError when unable to send. @@ -467,7 +509,16 @@ def _send(socket, data): raise _SendFailed() total_sent += sent - def _send_request(self, socket, host, method, path, headers, data, json): + def _send_request( + self, + socket: SocketType, + host: str, + method: str, + path: str, + headers: List[Dict[str, str]], + data: Any, + json: Any, + ): # pylint: disable=too-many-arguments self._send(socket, bytes(method, "utf-8")) self._send(socket, b" /") @@ -513,8 +564,15 @@ def _send_request(self, socket, host, method, path, headers, data, json): # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals def request( - self, method, url, data=None, json=None, headers=None, stream=False, timeout=60 - ): + self, + method: str, + url: str, + data: Optional[Any] = None, + json: Optional[Any] = None, + headers: Optional[List[Dict[str, str]]] = None, + stream: bool = False, + timeout: float = 60, + ) -> Response: """Perform an HTTP request to the given url which we will parse to determine whether to use SSL ('https://') or not. We can also send some provided 'data' or a json dictionary which we will stringify. 'headers' is optional HTTP headers @@ -604,27 +662,27 @@ def request( self._last_response = resp return resp - def head(self, url, **kw): + def head(self, url: str, **kw) -> Response: """Send HTTP HEAD request""" return self.request("HEAD", url, **kw) - def get(self, url, **kw): + def get(self, url: str, **kw) -> Response: """Send HTTP GET request""" return self.request("GET", url, **kw) - def post(self, url, **kw): + def post(self, url: str, **kw) -> Response: """Send HTTP POST request""" return self.request("POST", url, **kw) - def put(self, url, **kw): + def put(self, url: str, **kw) -> Response: """Send HTTP PUT request""" return self.request("PUT", url, **kw) - def patch(self, url, **kw): + def patch(self, url: str, **kw) -> Response: """Send HTTP PATCH request""" return self.request("PATCH", url, **kw) - def delete(self, url, **kw): + def delete(self, url: str, **kw) -> Response: """Send HTTP DELETE request""" return self.request("DELETE", url, **kw) @@ -635,7 +693,7 @@ def delete(self, url, **kw): class _FakeSSLSocket: - def __init__(self, socket, tls_mode): + def __init__(self, socket: SocketType, tls_mode: int) -> None: self._socket = socket self._mode = tls_mode self.settimeout = socket.settimeout @@ -643,7 +701,7 @@ def __init__(self, socket, tls_mode): self.recv = socket.recv self.close = socket.close - def connect(self, address): + def connect(self, address: Union[bytes, str]) -> None: """connect wrapper to add non-standard mode parameter""" try: return self._socket.connect(address, self._mode) @@ -652,16 +710,20 @@ def connect(self, address): class _FakeSSLContext: - def __init__(self, iface): + def __init__(self, iface: InterfaceType) -> None: self._iface = iface - def wrap_socket(self, socket, server_hostname=None): + def wrap_socket( + self, socket: SocketType, server_hostname: Optional[str] = None + ) -> _FakeSSLSocket: """Return the same socket""" # pylint: disable=unused-argument return _FakeSSLSocket(socket, self._iface.TLS_MODE) -def set_socket(sock, iface=None): +def set_socket( + sock: SocketpoolModuleType, iface: Optional[InterfaceType] = None +) -> None: """Legacy API for setting the socket and network interface. Use a `Session` instead.""" global _default_session # pylint: disable=global-statement,invalid-name if not iface: @@ -672,7 +734,15 @@ def set_socket(sock, iface=None): sock.set_interface(iface) -def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1): +def request( + method: str, + url: str, + data: Optional[Any] = None, + json: Optional[Any] = None, + headers: Optional[List[Dict[str, str]]] = None, + stream: bool = False, + timeout: float = 1, +) -> None: """Send HTTP request""" # pylint: disable=too-many-arguments _default_session.request( @@ -686,31 +756,31 @@ def request(method, url, data=None, json=None, headers=None, stream=False, timeo ) -def head(url, **kw): +def head(url: str, **kw): """Send HTTP HEAD request""" return _default_session.request("HEAD", url, **kw) -def get(url, **kw): +def get(url: str, **kw): """Send HTTP GET request""" return _default_session.request("GET", url, **kw) -def post(url, **kw): +def post(url: str, **kw): """Send HTTP POST request""" return _default_session.request("POST", url, **kw) -def put(url, **kw): +def put(url: str, **kw): """Send HTTP PUT request""" return _default_session.request("PUT", url, **kw) -def patch(url, **kw): +def patch(url: str, **kw): """Send HTTP PATCH request""" return _default_session.request("PATCH", url, **kw) -def delete(url, **kw): +def delete(url: str, **kw): """Send HTTP DELETE request""" return _default_session.request("DELETE", url, **kw) diff --git a/docs/conf.py b/docs/conf.py index c0e0389..4a485bc 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,7 @@ # Uncomment the below if you use native CircuitPython modules such as # digitalio, micropython and busio. List the modules you use. Without it, the # autodoc module docs will fail to generate with a warning. -# autodoc_mock_imports = ["digitalio", "busio"] +autodoc_mock_imports = ["adafruit_esp32spi", "adafruit_wiznet5k", "adafruit_fona"] intersphinx_mapping = {