From 16b6c6db1e9a54e664c5eb9ef4f2848731c482a5 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sun, 19 May 2024 21:14:10 -0700 Subject: [PATCH] Make sure socket is closed on exception --- adafruit_minimqtt/adafruit_minimqtt.py | 12 +++++++++--- tests/test_backoff.py | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 82b43abb..22b1f80a 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -438,6 +438,7 @@ def connect( self.logger.warning(f"Socket error when connecting: {e}") backoff = False except MMQTTException as e: + self._close_socket() self.logger.info(f"MMQT error: {e}") if e.code in [ CONNACK_ERROR_INCORECT_USERNAME_PASSWORD, @@ -452,9 +453,9 @@ def connect( exc_msg = "Repeated connect failures" else: exc_msg = "Connect failure" + if last_exception: raise MMQTTException(exc_msg) from last_exception - raise MMQTTException(exc_msg) # pylint: disable=too-many-branches, too-many-statements, too-many-locals @@ -565,6 +566,12 @@ def _connect( f"No data received from broker for {self._recv_timeout} seconds." ) + def _close_socket(self): + if self._sock: + self.logger.debug("Closing socket") + self._connection_manager.close_socket(self._sock) + self._sock = None + # pylint: disable=no-self-use def _encode_remaining_length( self, fixed_header: bytearray, remaining_length: int @@ -593,8 +600,7 @@ def disconnect(self) -> None: self._sock.send(MQTT_DISCONNECT) except RuntimeError as e: self.logger.warning(f"Unable to send DISCONNECT packet: {e}") - self.logger.debug("Closing socket") - self._connection_manager.close_socket(self._sock) + self._close_socket() self._is_connected = False self._subscribed_topics = [] self._last_msg_sent_timestamp = 0 diff --git a/tests/test_backoff.py b/tests/test_backoff.py index ce6097fa..44d2da9d 100644 --- a/tests/test_backoff.py +++ b/tests/test_backoff.py @@ -51,6 +51,7 @@ def test_failing_connect(self) -> None: print("connecting") with pytest.raises(MQTT.MMQTTException) as context: mqtt_client.connect() + assert mqtt_client._sock is None assert "Repeated connect failures" in str(context) mock_method.assert_called() @@ -86,6 +87,7 @@ def test_unauthorized(self) -> None: print("connecting") with pytest.raises(MQTT.MMQTTException) as context: mqtt_client.connect() + assert mqtt_client._sock is None assert "Connection Refused - Unauthorized" in str(context) mock_method.assert_called()