From 906e6764d871897bb175f5e878b7e59370d3b884 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Mon, 13 May 2024 14:32:32 -0700 Subject: [PATCH 1/2] Don't retry when MQTT response is unauthorized --- adafruit_minimqtt/adafruit_minimqtt.py | 30 +++++++++++++++----- tests/test_backoff.py | 38 +++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 622c6e68..82b43abb 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -72,12 +72,18 @@ MQTT_PKT_TYPE_MASK = const(0xF0) +CONNACK_ERROR_INCORRECT_PROTOCOL = const(0x01) +CONNACK_ERROR_ID_REJECTED = const(0x02) +CONNACK_ERROR_SERVER_UNAVAILABLE = const(0x03) +CONNACK_ERROR_INCORECT_USERNAME_PASSWORD = const(0x04) +CONNACK_ERROR_UNAUTHORIZED = const(0x05) + CONNACK_ERRORS = { - const(0x01): "Connection Refused - Incorrect Protocol Version", - const(0x02): "Connection Refused - ID Rejected", - const(0x03): "Connection Refused - Server unavailable", - const(0x04): "Connection Refused - Incorrect username/password", - const(0x05): "Connection Refused - Unauthorized", + CONNACK_ERROR_INCORRECT_PROTOCOL: "Connection Refused - Incorrect Protocol Version", + CONNACK_ERROR_ID_REJECTED: "Connection Refused - ID Rejected", + CONNACK_ERROR_SERVER_UNAVAILABLE: "Connection Refused - Server unavailable", + CONNACK_ERROR_INCORECT_USERNAME_PASSWORD: "Connection Refused - Incorrect username/password", + CONNACK_ERROR_UNAUTHORIZED: "Connection Refused - Unauthorized", } _default_sock = None # pylint: disable=invalid-name @@ -87,6 +93,10 @@ class MMQTTException(Exception): """MiniMQTT Exception class.""" + def __init__(self, error, code=None): + super().__init__(error, code) + self.code = code + class NullLogger: """Fake logger class that does not do anything""" @@ -428,8 +438,14 @@ def connect( self.logger.warning(f"Socket error when connecting: {e}") backoff = False except MMQTTException as e: - last_exception = e self.logger.info(f"MMQT error: {e}") + if e.code in [ + CONNACK_ERROR_INCORECT_USERNAME_PASSWORD, + CONNACK_ERROR_UNAUTHORIZED, + ]: + # No sense trying these again, re-raise + raise + last_exception = e backoff = True if self._reconnect_attempts_max > 1: @@ -535,7 +551,7 @@ def _connect( rc = self._sock_exact_recv(3) assert rc[0] == 0x02 if rc[2] != 0x00: - raise MMQTTException(CONNACK_ERRORS[rc[2]]) + raise MMQTTException(CONNACK_ERRORS[rc[2]], code=rc[2]) self._is_connected = True result = rc[0] & 1 if self.on_connect is not None: diff --git a/tests/test_backoff.py b/tests/test_backoff.py index e26d07a4..ce6097fa 100644 --- a/tests/test_backoff.py +++ b/tests/test_backoff.py @@ -18,18 +18,24 @@ class TestExpBackOff: """basic exponential back-off test""" connect_times = [] + raise_exception = None # pylint: disable=unused-argument def fake_connect(self, arg): """connect() replacement that records the call times and always raises OSError""" self.connect_times.append(time.monotonic()) - raise OSError("this connect failed") + raise self.raise_exception def test_failing_connect(self) -> None: """test that exponential back-off is used when connect() always raises OSError""" # use RFC 1918 address to avoid dealing with IPv6 in the call list below host = "172.40.0.3" port = 1883 + self.connect_times = [] + error_code = MQTT.CONNACK_ERROR_SERVER_UNAVAILABLE + self.raise_exception = MQTT.MMQTTException( + MQTT.CONNACK_ERRORS[error_code], code=error_code + ) with patch.object(socket.socket, "connect") as mock_method: mock_method.side_effect = self.fake_connect @@ -54,3 +60,33 @@ def test_failing_connect(self) -> None: print(f"connect() call times: {self.connect_times}") for i in range(1, connect_retries): assert self.connect_times[i] >= 2**i + + def test_unauthorized(self) -> None: + """test that exponential back-off is used when connect() always raises OSError""" + # use RFC 1918 address to avoid dealing with IPv6 in the call list below + host = "172.40.0.3" + port = 1883 + self.connect_times = [] + error_code = MQTT.CONNACK_ERROR_UNAUTHORIZED + self.raise_exception = MQTT.MMQTTException( + MQTT.CONNACK_ERRORS[error_code], code=error_code + ) + + with patch.object(socket.socket, "connect") as mock_method: + mock_method.side_effect = self.fake_connect + + connect_retries = 3 + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + socket_pool=socket, + ssl_context=ssl.create_default_context(), + connect_retries=connect_retries, + ) + print("connecting") + with pytest.raises(MQTT.MMQTTException) as context: + mqtt_client.connect() + assert "Connection Refused - Unauthorized" in str(context) + + mock_method.assert_called() + assert len(self.connect_times) == 1 From 16b6c6db1e9a54e664c5eb9ef4f2848731c482a5 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sun, 19 May 2024 21:14:10 -0700 Subject: [PATCH 2/2] Make sure socket is closed on exception --- adafruit_minimqtt/adafruit_minimqtt.py | 12 +++++++++--- tests/test_backoff.py | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 82b43abb..22b1f80a 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -438,6 +438,7 @@ def connect( self.logger.warning(f"Socket error when connecting: {e}") backoff = False except MMQTTException as e: + self._close_socket() self.logger.info(f"MMQT error: {e}") if e.code in [ CONNACK_ERROR_INCORECT_USERNAME_PASSWORD, @@ -452,9 +453,9 @@ def connect( exc_msg = "Repeated connect failures" else: exc_msg = "Connect failure" + if last_exception: raise MMQTTException(exc_msg) from last_exception - raise MMQTTException(exc_msg) # pylint: disable=too-many-branches, too-many-statements, too-many-locals @@ -565,6 +566,12 @@ def _connect( f"No data received from broker for {self._recv_timeout} seconds." ) + def _close_socket(self): + if self._sock: + self.logger.debug("Closing socket") + self._connection_manager.close_socket(self._sock) + self._sock = None + # pylint: disable=no-self-use def _encode_remaining_length( self, fixed_header: bytearray, remaining_length: int @@ -593,8 +600,7 @@ def disconnect(self) -> None: self._sock.send(MQTT_DISCONNECT) except RuntimeError as e: self.logger.warning(f"Unable to send DISCONNECT packet: {e}") - self.logger.debug("Closing socket") - self._connection_manager.close_socket(self._sock) + self._close_socket() self._is_connected = False self._subscribed_topics = [] self._last_msg_sent_timestamp = 0 diff --git a/tests/test_backoff.py b/tests/test_backoff.py index ce6097fa..44d2da9d 100644 --- a/tests/test_backoff.py +++ b/tests/test_backoff.py @@ -51,6 +51,7 @@ def test_failing_connect(self) -> None: print("connecting") with pytest.raises(MQTT.MMQTTException) as context: mqtt_client.connect() + assert mqtt_client._sock is None assert "Repeated connect failures" in str(context) mock_method.assert_called() @@ -86,6 +87,7 @@ def test_unauthorized(self) -> None: print("connecting") with pytest.raises(MQTT.MMQTTException) as context: mqtt_client.connect() + assert mqtt_client._sock is None assert "Connection Refused - Unauthorized" in str(context) mock_method.assert_called()