Skip to content

Commit

Permalink
reduce the use of MMQTTException
Browse files Browse the repository at this point in the history
use it for protocol/network/system level errors only

fixes adafruit#201
  • Loading branch information
vladak committed Jan 28, 2025
1 parent 57ed4f0 commit fdd436e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 29 deletions.
63 changes: 38 additions & 25 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,26 @@


class MMQTTException(Exception):
"""MiniMQTT Exception class."""
"""
MiniMQTT Exception class.
Raised for various mostly protocol or network/system level errors.
In general, the robust way to recover is to call reconnect().
"""

def __init__(self, error, code=None):
super().__init__(error, code)
self.code = code


class MMQTTStateError(MMQTTException):
"""
MiniMQTT invalid state error.
Raised e.g. if a function is called in unexpected state.
"""


class NullLogger:
"""Fake logger class that does not do anything"""

Expand Down Expand Up @@ -163,7 +176,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
self._use_binary_mode = use_binary_mode

if recv_timeout <= socket_timeout:
raise MMQTTException("recv_timeout must be strictly greater than socket_timeout")
raise ValueError("recv_timeout must be strictly greater than socket_timeout")
self._socket_timeout = socket_timeout
self._recv_timeout = recv_timeout

Expand All @@ -181,7 +194,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
self._reconnect_timeout = float(0)
self._reconnect_maximum_backoff = 32
if connect_retries <= 0:
raise MMQTTException("connect_retries must be positive")
raise ValueError("connect_retries must be positive")
self._reconnect_attempts_max = connect_retries

self.broker = broker
Expand All @@ -190,7 +203,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
if (
self._password and len(password.encode("utf-8")) > MQTT_TOPIC_LENGTH_LIMIT
): # [MQTT-3.1.3.5]
raise MMQTTException("Password length is too large.")
raise ValueError("Password length is too large.")

# The connection will be insecure unless is_ssl is set to True.
# If the port is not specified, the security will be set based on the is_ssl parameter.
Expand Down Expand Up @@ -286,28 +299,27 @@ def will_set(
"""
self.logger.debug("Setting last will properties")
if self._is_connected:
raise MMQTTException("Last Will should only be called before connect().")
raise MMQTTStateError("Last Will should only be called before connect().")

# check topic/msg/qos kwargs
self._valid_topic(topic)
if "+" in topic or "#" in topic:
raise MMQTTException("Publish topic can not contain wildcards.")
raise ValueError("Publish topic can not contain wildcards.")

if msg is None:
raise MMQTTException("Message can not be None.")
raise ValueError("Message can not be None.")
if isinstance(msg, (int, float)):
msg = str(msg).encode("ascii")
elif isinstance(msg, str):
msg = str(msg).encode("utf-8")
elif isinstance(msg, bytes):
pass
else:
raise MMQTTException("Invalid message data type.")
raise ValueError("Invalid message data type.")
if len(msg) > MQTT_MSG_MAX_SZ:
raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")
raise ValueError(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")

self._valid_qos(qos)
assert 0 <= qos <= 1, "Quality of Service Level 2 is unsupported by this library."

# fixed header. [3.3.1.2], [3.3.1.3]
pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1])
Expand Down Expand Up @@ -390,7 +402,7 @@ def username_pw_set(self, username: str, password: Optional[str] = None) -> None
"""
if self._is_connected:
raise MMQTTException("This method must be called before connect().")
raise MMQTTStateError("This method must be called before connect().")
self._username = username
if password is not None:
self._password = password
Expand Down Expand Up @@ -670,21 +682,22 @@ def publish( # noqa: PLR0912, Too many branches
self._connected()
self._valid_topic(topic)
if "+" in topic or "#" in topic:
raise MMQTTException("Publish topic can not contain wildcards.")
raise ValueError("Publish topic can not contain wildcards.")
# check msg/qos kwargs
if msg is None:
raise MMQTTException("Message can not be None.")
raise ValueError("Message can not be None.")
if isinstance(msg, (int, float)):
msg = str(msg).encode("ascii")
elif isinstance(msg, str):
msg = str(msg).encode("utf-8")
elif isinstance(msg, bytes):
pass
else:
raise MMQTTException("Invalid message data type.")
raise ValueError("Invalid message data type.")
if len(msg) > MQTT_MSG_MAX_SZ:
raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")
assert 0 <= qos <= 1, "Quality of Service Level 2 is unsupported by this library."
raise ValueError(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")

self._valid_qos(qos)

# fixed header. [3.3.1.2], [3.3.1.3]
pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1])
Expand Down Expand Up @@ -849,7 +862,7 @@ def unsubscribe( # noqa: PLR0912, Too many branches
topics.append(t)
for t in topics:
if t not in self._subscribed_topics:
raise MMQTTException("Topic must be subscribed to before attempting unsubscribe.")
raise MMQTTStateError("Topic must be subscribed to before attempting unsubscribe.")
# Assemble packet
self.logger.debug("Sending UNSUBSCRIBE to broker...")
fixed_header = bytearray([MQTT_UNSUB])
Expand Down Expand Up @@ -959,7 +972,7 @@ def loop(self, timeout: float = 1.0) -> Optional[list[int]]:
"""
if timeout < self._socket_timeout:
raise MMQTTException(
raise ValueError(
f"loop timeout ({timeout}) must be >= "
+ f"socket timeout ({self._socket_timeout}))"
)
Expand Down Expand Up @@ -1153,13 +1166,13 @@ def _valid_topic(topic: str) -> None:
"""
if topic is None:
raise MMQTTException("Topic may not be NoneType")
raise ValueError("Topic may not be NoneType")
# [MQTT-4.7.3-1]
if not topic:
raise MMQTTException("Topic may not be empty.")
raise ValueError("Topic may not be empty.")
# [MQTT-4.7.3-3]
if len(topic.encode("utf-8")) > MQTT_TOPIC_LENGTH_LIMIT:
raise MMQTTException("Topic length is too large.")
raise ValueError(f"Encoded topic length is larger than {MQTT_TOPIC_LENGTH_LIMIT}")

@staticmethod
def _valid_qos(qos_level: int) -> None:
Expand All @@ -1170,16 +1183,16 @@ def _valid_qos(qos_level: int) -> None:
"""
if isinstance(qos_level, int):
if qos_level < 0 or qos_level > 2:
raise MMQTTException("QoS must be between 1 and 2.")
raise NotImplementedError("QoS must be between 1 and 2.")
else:
raise MMQTTException("QoS must be an integer.")
raise ValueError("QoS must be an integer.")

def _connected(self) -> None:
"""Returns MQTT client session status as True if connected, raises
a `MMQTTException` if `False`.
a `MMQTTStateError exception` if `False`.
"""
if not self.is_connected():
raise MMQTTException("MiniMQTT is not connected")
raise MMQTTStateError("MiniMQTT is not connected")

def is_connected(self) -> bool:
"""Returns MQTT client session status as True if connected, False
Expand Down
8 changes: 4 additions & 4 deletions tests/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_loop_basic(self) -> None:

def test_loop_timeout_vs_socket_timeout(self):
"""
loop() should throw MMQTTException if the timeout argument
loop() should throw ValueError if the timeout argument
is bigger than the socket timeout.
"""
mqtt_client = MQTT.MQTT(
Expand All @@ -167,14 +167,14 @@ def test_loop_timeout_vs_socket_timeout(self):
)

mqtt_client.is_connected = lambda: True
with pytest.raises(MQTT.MMQTTException) as context:
with pytest.raises(ValueError) as context:
mqtt_client.loop(timeout=0.5)

assert "loop timeout" in str(context)

def test_loop_is_connected(self):
"""
loop() should throw MMQTTException if not connected
loop() should throw MMQTTStateError if not connected
"""
mqtt_client = MQTT.MQTT(
broker="127.0.0.1",
Expand All @@ -183,7 +183,7 @@ def test_loop_is_connected(self):
ssl_context=ssl.create_default_context(),
)

with pytest.raises(MQTT.MMQTTException) as context:
with pytest.raises(MQTT.MMQTTStateError) as context:
mqtt_client.loop(timeout=1)

assert "not connected" in str(context)
Expand Down

0 comments on commit fdd436e

Please sign in to comment.