Skip to content

Commit

Permalink
reduce the use of MMQTTException
Browse files Browse the repository at this point in the history
use it for protocol/network/system level errors only

fixes adafruit#201
  • Loading branch information
vladak committed Jan 28, 2025
1 parent 57ed4f0 commit 315f3df
Showing 1 changed file with 74 additions and 34 deletions.
108 changes: 74 additions & 34 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,26 @@


class MMQTTException(Exception):
"""MiniMQTT Exception class."""
"""
MiniMQTT Exception class.
Raised for various mostly protocol or network/system level errors.
In general, the robust way to recover is to call reconnect().
"""

def __init__(self, error, code=None):
super().__init__(error, code)
self.code = code


class MMQTTStateError(MMQTTException):
"""
MiniMQTT invalid state error.
Raised e.g. if a function is called in unexpected state.
"""


class NullLogger:
"""Fake logger class that does not do anything"""

Expand Down Expand Up @@ -163,7 +176,9 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
self._use_binary_mode = use_binary_mode

if recv_timeout <= socket_timeout:
raise MMQTTException("recv_timeout must be strictly greater than socket_timeout")
raise ValueError(
"recv_timeout must be strictly greater than socket_timeout"
)
self._socket_timeout = socket_timeout
self._recv_timeout = recv_timeout

Expand All @@ -181,7 +196,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
self._reconnect_timeout = float(0)
self._reconnect_maximum_backoff = 32
if connect_retries <= 0:
raise MMQTTException("connect_retries must be positive")
raise ValueError("connect_retries must be positive")
self._reconnect_attempts_max = connect_retries

self.broker = broker
Expand All @@ -190,7 +205,7 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
if (
self._password and len(password.encode("utf-8")) > MQTT_TOPIC_LENGTH_LIMIT
): # [MQTT-3.1.3.5]
raise MMQTTException("Password length is too large.")
raise ValueError("Password length is too large.")

# The connection will be insecure unless is_ssl is set to True.
# If the port is not specified, the security will be set based on the is_ssl parameter.
Expand Down Expand Up @@ -286,28 +301,27 @@ def will_set(
"""
self.logger.debug("Setting last will properties")
if self._is_connected:
raise MMQTTException("Last Will should only be called before connect().")
raise MMQTTStateError("Last Will should only be called before connect().")

# check topic/msg/qos kwargs
self._valid_topic(topic)
if "+" in topic or "#" in topic:
raise MMQTTException("Publish topic can not contain wildcards.")
raise ValueError("Publish topic can not contain wildcards.")

if msg is None:
raise MMQTTException("Message can not be None.")
raise ValueError("Message can not be None.")
if isinstance(msg, (int, float)):
msg = str(msg).encode("ascii")
elif isinstance(msg, str):
msg = str(msg).encode("utf-8")
elif isinstance(msg, bytes):
pass
else:
raise MMQTTException("Invalid message data type.")
raise ValueError("Invalid message data type.")
if len(msg) > MQTT_MSG_MAX_SZ:
raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")
raise ValueError(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")

self._valid_qos(qos)
assert 0 <= qos <= 1, "Quality of Service Level 2 is unsupported by this library."

# fixed header. [3.3.1.2], [3.3.1.3]
pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1])
Expand Down Expand Up @@ -357,7 +371,9 @@ def remove_topic_callback(self, mqtt_topic: str) -> None:
try:
del self._on_message_filtered[mqtt_topic]
except KeyError:
raise KeyError("MQTT topic callback not added with add_topic_callback.") from None
raise KeyError(
"MQTT topic callback not added with add_topic_callback."
) from None

@property
def on_message(self):
Expand Down Expand Up @@ -390,7 +406,7 @@ def username_pw_set(self, username: str, password: Optional[str] = None) -> None
"""
if self._is_connected:
raise MMQTTException("This method must be called before connect().")
raise MMQTTStateError("This method must be called before connect().")
self._username = username
if password is not None:
self._password = password
Expand Down Expand Up @@ -541,15 +557,20 @@ def _connect( # noqa: PLR0912, PLR0913, PLR0915, Too many branches, Too many ar
remaining_length = 12 + len(self.client_id.encode("utf-8"))
if self._username is not None:
remaining_length += (
2 + len(self._username.encode("utf-8")) + 2 + len(self._password.encode("utf-8"))
2
+ len(self._username.encode("utf-8"))
+ 2
+ len(self._password.encode("utf-8"))
)
var_header[7] |= 0xC0
if self.keep_alive:
assert self.keep_alive < MQTT_TOPIC_LENGTH_LIMIT
var_header[8] |= self.keep_alive >> 8
var_header[9] |= self.keep_alive & 0x00FF
if self._lw_topic:
remaining_length += 2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg)
remaining_length += (
2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg)
)
var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
var_header[7] |= self._lw_retain << 5

Expand Down Expand Up @@ -597,7 +618,9 @@ def _close_socket(self):
self._connection_manager.close_socket(self._sock)
self._sock = None

def _encode_remaining_length(self, fixed_header: bytearray, remaining_length: int) -> None:
def _encode_remaining_length(
self, fixed_header: bytearray, remaining_length: int
) -> None:
"""Encode Remaining Length [2.2.3]"""
if remaining_length > 268_435_455:
raise MMQTTException("invalid remaining length")
Expand Down Expand Up @@ -670,21 +693,23 @@ def publish( # noqa: PLR0912, Too many branches
self._connected()
self._valid_topic(topic)
if "+" in topic or "#" in topic:
raise MMQTTException("Publish topic can not contain wildcards.")
raise ValueError("Publish topic can not contain wildcards.")
# check msg/qos kwargs
if msg is None:
raise MMQTTException("Message can not be None.")
raise ValueError("Message can not be None.")
if isinstance(msg, (int, float)):
msg = str(msg).encode("ascii")
elif isinstance(msg, str):
msg = str(msg).encode("utf-8")
elif isinstance(msg, bytes):
pass
else:
raise MMQTTException("Invalid message data type.")
raise ValueError("Invalid message data type.")
if len(msg) > MQTT_MSG_MAX_SZ:
raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")
assert 0 <= qos <= 1, "Quality of Service Level 2 is unsupported by this library."
raise ValueError(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")
assert (
0 <= qos <= 1
), "Quality of Service Level 2 is unsupported by this library."

# fixed header. [3.3.1.2], [3.3.1.3]
pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1])
Expand Down Expand Up @@ -849,7 +874,9 @@ def unsubscribe( # noqa: PLR0912, Too many branches
topics.append(t)
for t in topics:
if t not in self._subscribed_topics:
raise MMQTTException("Topic must be subscribed to before attempting unsubscribe.")
raise MMQTTStateError(
"Topic must be subscribed to before attempting unsubscribe."
)
# Assemble packet
self.logger.debug("Sending UNSUBSCRIBE to broker...")
fixed_header = bytearray([MQTT_UNSUB])
Expand Down Expand Up @@ -906,7 +933,9 @@ def _recompute_reconnect_backoff(self) -> None:
"""
self._reconnect_attempt = self._reconnect_attempt + 1
self._reconnect_timeout = 2**self._reconnect_attempt
self.logger.debug(f"Reconnect timeout computed to {self._reconnect_timeout:.2f}")
self.logger.debug(
f"Reconnect timeout computed to {self._reconnect_timeout:.2f}"
)

if self._reconnect_timeout > self._reconnect_maximum_backoff:
self.logger.debug(
Expand All @@ -917,7 +946,9 @@ def _recompute_reconnect_backoff(self) -> None:
# Add a sub-second jitter.
# Even truncated timeout should have jitter added to it. This is why it is added here.
jitter = randint(0, 1000) / 1000
self.logger.debug(f"adding jitter {jitter:.2f} to {self._reconnect_timeout:.2f} seconds")
self.logger.debug(
f"adding jitter {jitter:.2f} to {self._reconnect_timeout:.2f} seconds"
)
self._reconnect_timeout += jitter

def _reset_reconnect_backoff(self) -> None:
Expand All @@ -942,7 +973,9 @@ def reconnect(self, resub_topics: bool = True) -> int:
ret = self.connect()
self.logger.debug("Reconnected with broker")
if resub_topics:
self.logger.debug("Attempting to resubscribe to previously subscribed topics.")
self.logger.debug(
"Attempting to resubscribe to previously subscribed topics."
)
subscribed_topics = self._subscribed_topics.copy()
self._subscribed_topics = []
while subscribed_topics:
Expand All @@ -959,7 +992,7 @@ def loop(self, timeout: float = 1.0) -> Optional[list[int]]:
"""
if timeout < self._socket_timeout:
raise MMQTTException(
raise ValueError(
f"loop timeout ({timeout}) must be >= "
+ f"socket timeout ({self._socket_timeout}))"
)
Expand All @@ -971,7 +1004,10 @@ def loop(self, timeout: float = 1.0) -> Optional[list[int]]:
rcs = []

while True:
if ticks_diff(ticks_ms(), self._last_msg_sent_timestamp) / 1000 >= self.keep_alive:
if (
ticks_diff(ticks_ms(), self._last_msg_sent_timestamp) / 1000
>= self.keep_alive
):
# Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server
self.logger.debug(
"KeepAlive period elapsed - requesting a PINGRESP from the server..."
Expand Down Expand Up @@ -1077,7 +1113,9 @@ def _decode_remaining_length(self) -> int:
return n
sh += 7

def _sock_exact_recv(self, bufsize: int, timeout: Optional[float] = None) -> bytearray:
def _sock_exact_recv(
self, bufsize: int, timeout: Optional[float] = None
) -> bytearray:
"""Reads _exact_ number of bytes from the connected socket. Will only return
bytearray with the exact number of bytes requested.
Expand Down Expand Up @@ -1153,13 +1191,15 @@ def _valid_topic(topic: str) -> None:
"""
if topic is None:
raise MMQTTException("Topic may not be NoneType")
raise ValueError("Topic may not be NoneType")
# [MQTT-4.7.3-1]
if not topic:
raise MMQTTException("Topic may not be empty.")
raise ValueError("Topic may not be empty.")
# [MQTT-4.7.3-3]
if len(topic.encode("utf-8")) > MQTT_TOPIC_LENGTH_LIMIT:
raise MMQTTException("Topic length is too large.")
raise ValueError(
f"Encoded topic length is larger than {MQTT_TOPIC_LENGTH_LIMIT}"
)

@staticmethod
def _valid_qos(qos_level: int) -> None:
Expand All @@ -1170,16 +1210,16 @@ def _valid_qos(qos_level: int) -> None:
"""
if isinstance(qos_level, int):
if qos_level < 0 or qos_level > 2:
raise MMQTTException("QoS must be between 1 and 2.")
raise NotImplementedError("QoS must be between 1 and 2.")
else:
raise MMQTTException("QoS must be an integer.")
raise ValueError("QoS must be an integer.")

def _connected(self) -> None:
"""Returns MQTT client session status as True if connected, raises
a `MMQTTException` if `False`.
a `MMQTTStateError exception` if `False`.
"""
if not self.is_connected():
raise MMQTTException("MiniMQTT is not connected")
raise MMQTTStateError("MiniMQTT is not connected")

def is_connected(self) -> bool:
"""Returns MQTT client session status as True if connected, False
Expand Down

0 comments on commit 315f3df

Please sign in to comment.