Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve ping handling #199

Merged
merged 3 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(
self._is_connected = False
self._msg_size_lim = MQTT_MSG_SZ_LIM
self._pid = 0
self._timestamp: float = 0
self._last_msg_sent_timestamp: float = 0
self.logger = NullLogger()
"""An optional logging attribute that can be set with with a Logger
to enable debug logging."""
Expand Down Expand Up @@ -640,6 +640,7 @@ def _connect(
if self._username is not None:
self._send_str(self._username)
self._send_str(self._password)
self._last_msg_sent_timestamp = self.get_monotonic_time()
self.logger.debug("Receiving CONNACK packet from broker")
stamp = self.get_monotonic_time()
while True:
Expand Down Expand Up @@ -694,6 +695,7 @@ def disconnect(self) -> None:
self._sock.close()
self._is_connected = False
self._subscribed_topics = []
self._last_msg_sent_timestamp = 0
if self.on_disconnect is not None:
self.on_disconnect(self, self.user_data, 0)

Expand All @@ -707,6 +709,7 @@ def ping(self) -> list[int]:
self._sock.send(MQTT_PINGREQ)
ping_timeout = self.keep_alive
stamp = self.get_monotonic_time()
self._last_msg_sent_timestamp = stamp
rc, rcs = None, []
while rc != MQTT_PINGRESP:
rc = self._wait_for_msg()
Expand Down Expand Up @@ -781,6 +784,7 @@ def publish(
self._sock.send(pub_hdr_fixed)
self._sock.send(pub_hdr_var)
self._sock.send(msg)
self._last_msg_sent_timestamp = self.get_monotonic_time()
if qos == 0 and self.on_publish is not None:
self.on_publish(self, self.user_data, topic, self._pid)
if qos == 1:
Expand Down Expand Up @@ -858,6 +862,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
self.logger.debug(f"payload: {payload}")
self._sock.send(payload)
stamp = self.get_monotonic_time()
self._last_msg_sent_timestamp = stamp
while True:
op = self._wait_for_msg()
if op is None:
Expand Down Expand Up @@ -933,6 +938,7 @@ def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
for t in topics:
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
self._sock.send(payload)
self._last_msg_sent_timestamp = self.get_monotonic_time()
self.logger.debug("Waiting for UNSUBACK...")
while True:
stamp = self.get_monotonic_time()
Expand Down Expand Up @@ -1022,7 +1028,6 @@ def reconnect(self, resub_topics: bool = True) -> int:
return ret

def loop(self, timeout: float = 0) -> Optional[list[int]]:
# pylint: disable = too-many-return-statements
"""Non-blocking message loop. Use this method to check for incoming messages.
Returns list of packet types of any messages received or None.

Expand All @@ -1038,23 +1043,27 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]:

self._connected()
self.logger.debug(f"waiting for messages for {timeout} seconds")
if self._timestamp == 0:
self._timestamp = self.get_monotonic_time()
current_time = self.get_monotonic_time()
if current_time - self._timestamp >= self.keep_alive:
self._timestamp = 0
# Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server
self.logger.debug(
"KeepAlive period elapsed - requesting a PINGRESP from the server..."
)
rcs = self.ping()
return rcs

stamp = self.get_monotonic_time()
rcs = []

while True:
rc = self._wait_for_msg(timeout=timeout)
if (
self.get_monotonic_time() - self._last_msg_sent_timestamp
>= self.keep_alive
):
# Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server
self.logger.debug(
"KeepAlive period elapsed - requesting a PINGRESP from the server..."
)
rcs.extend(self.ping())
# ping() itself contains a _wait_for_msg() loop which might have taken a while,
# so check here as well.
if self.get_monotonic_time() - stamp > timeout:
self.logger.debug(f"Loop timed out after {timeout} seconds")
break

rc = self._wait_for_msg()
if rc is not None:
rcs.append(rc)
if self.get_monotonic_time() - stamp > timeout:
Expand Down
155 changes: 155 additions & 0 deletions tests/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,99 @@
import socket
import ssl
import time
import errno

from unittest import TestCase, main
from unittest.mock import patch
from unittest import mock

import adafruit_minimqtt.adafruit_minimqtt as MQTT


class Nulltet:
"""
Mock Socket that does nothing.

Inspired by the Mocket class from Adafruit_CircuitPython_Requests
"""

def __init__(self):
self.sent = bytearray()

self.timeout = mock.Mock()
self.connect = mock.Mock()
self.close = mock.Mock()

def send(self, bytes_to_send):
"""
Record the bytes. return the length of this bytearray.
"""
self.sent.extend(bytes_to_send)
return len(bytes_to_send)

# MiniMQTT checks for the presence of "recv_into" and switches behavior based on that.
# pylint: disable=unused-argument,no-self-use
def recv_into(self, retbuf, bufsize):
"""Always raise timeout exception."""
exc = OSError()
exc.errno = errno.ETIMEDOUT
raise exc


class Pingtet:
"""
Mock Socket tailored for PINGREQ testing.
Records sent data, hands out PINGRESP for each PINGREQ received.

Inspired by the Mocket class from Adafruit_CircuitPython_Requests
"""

PINGRESP = bytearray([0xD0, 0x00])

def __init__(self):
self._to_send = self.PINGRESP

self.sent = bytearray()

self.timeout = mock.Mock()
self.connect = mock.Mock()
self.close = mock.Mock()

self._got_pingreq = False

def send(self, bytes_to_send):
"""
Recognize PINGREQ and record the indication that it was received.
Assumes it was sent in one chunk (of 2 bytes).
Also record the bytes. return the length of this bytearray.
"""
self.sent.extend(bytes_to_send)
if bytes_to_send == b"\xc0\0":
self._got_pingreq = True
return len(bytes_to_send)

# MiniMQTT checks for the presence of "recv_into" and switches behavior based on that.
def recv_into(self, retbuf, bufsize):
"""
If the PINGREQ indication is on, return PINGRESP, otherwise raise timeout exception.
"""
if self._got_pingreq:
size = min(bufsize, len(self._to_send))
if size == 0:
return size
chop = self._to_send[0:size]
retbuf[0:] = chop
self._to_send = self._to_send[size:]
if len(self._to_send) == 0:
self._got_pingreq = False
self._to_send = self.PINGRESP
return size

exc = OSError()
exc.errno = errno.ETIMEDOUT
raise exc


class Loop(TestCase):
"""basic loop() test"""

Expand Down Expand Up @@ -54,6 +141,8 @@ def test_loop_basic(self) -> None:

time_before = time.monotonic()
timeout = random.randint(3, 8)
# pylint: disable=protected-access
mqtt_client._last_msg_sent_timestamp = mqtt_client.get_monotonic_time()
rcs = mqtt_client.loop(timeout=timeout)
time_after = time.monotonic()

Expand All @@ -64,6 +153,7 @@ def test_loop_basic(self) -> None:
assert rcs is not None
assert len(rcs) >= 1
expected_rc = self.INITIAL_RCS_VAL
# pylint: disable=not-an-iterable
for ret_code in rcs:
assert ret_code == expected_rc
expected_rc += 1
Expand Down Expand Up @@ -104,6 +194,71 @@ def test_loop_is_connected(self):

assert "not connected" in str(context.exception)

# pylint: disable=no-self-use
def test_loop_ping_timeout(self):
"""Verify that ping will be sent even with loop timeout bigger than keep alive timeout
and no outgoing messages are sent."""

recv_timeout = 2
keep_alive_timeout = recv_timeout * 2
mqtt_client = MQTT.MQTT(
broker="localhost",
port=1883,
ssl_context=ssl.create_default_context(),
connect_retries=1,
socket_timeout=1,
recv_timeout=recv_timeout,
keep_alive=keep_alive_timeout,
)

# patch is_connected() to avoid CONNECT/CONNACK handling.
mqtt_client.is_connected = lambda: True
mocket = Pingtet()
# pylint: disable=protected-access
mqtt_client._sock = mocket

start = time.monotonic()
res = mqtt_client.loop(timeout=2 * keep_alive_timeout)
assert time.monotonic() - start >= 2 * keep_alive_timeout
assert len(mocket.sent) > 0
assert len(res) == 2
assert set(res) == {int(0xD0)}

# pylint: disable=no-self-use
def test_loop_ping_vs_msgs_sent(self):
"""Verify that ping will not be sent unnecessarily."""

recv_timeout = 2
keep_alive_timeout = recv_timeout * 2
mqtt_client = MQTT.MQTT(
broker="localhost",
port=1883,
ssl_context=ssl.create_default_context(),
connect_retries=1,
socket_timeout=1,
recv_timeout=recv_timeout,
keep_alive=keep_alive_timeout,
)

# patch is_connected() to avoid CONNECT/CONNACK handling.
mqtt_client.is_connected = lambda: True

# With QoS=0 no PUBACK message is sent, so Nulltet can be used.
mocket = Nulltet()
# pylint: disable=protected-access
mqtt_client._sock = mocket

i = 0
topic = "foo"
message = "bar"
for _ in range(3 * keep_alive_timeout):
mqtt_client.publish(topic, message, qos=0)
mqtt_client.loop(1)
i += 1

# This means no other messages than the PUBLISH messages generated by the code above.
assert len(mocket.sent) == i * (2 + 2 + len(topic) + len(message))


if __name__ == "__main__":
main()
Loading