Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reconnect restoration #244

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
if port:
self.port = port

self.session_id = None

# define client identifier
if client_id:
# user-defined client_id MAY allow client_id's > 23 bytes or
Expand Down Expand Up @@ -528,6 +530,7 @@ def _connect( # noqa: PLR0912, PLR0913, PLR0915, Too many branches, Too many ar
is_ssl=self._is_ssl,
ssl_context=self._ssl_context,
)
self.session_id = session_id
self._backwards_compatible_sock = not hasattr(self._sock, "recv_into")

fixed_header = bytearray([0x10])
Expand Down Expand Up @@ -939,11 +942,18 @@ def reconnect(self, resub_topics: bool = True) -> int:
"""

self.logger.debug("Attempting to reconnect with MQTT broker")
ret = self.connect()
subscribed_topics = []
if self.is_connected():
# disconnect() will reset subscribed topics so stash them now.
if resub_topics:
subscribed_topics = self._subscribed_topics.copy()
self.disconnect()

ret = self.connect(session_id=self.session_id)
self.logger.debug("Reconnected with broker")
if resub_topics:

if resub_topics and subscribed_topics:
self.logger.debug("Attempting to resubscribe to previously subscribed topics.")
subscribed_topics = self._subscribed_topics.copy()
self._subscribed_topics = []
while subscribed_topics:
feed = subscribed_topics.pop()
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@pytest.fixture(autouse=True)
def reset_connection_manager(monkeypatch):
"""Reset the ConnectionManager, since it's a singlton and will hold data"""
"""Reset the ConnectionManager, since it's a singleton and will hold data"""
monkeypatch.setattr(
"adafruit_minimqtt.adafruit_minimqtt.get_connection_manager",
adafruit_connection_manager.ConnectionManager,
Expand Down
247 changes: 247 additions & 0 deletions tests/test_reconnect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# SPDX-FileCopyrightText: 2025 Vladimír Kotal
#
# SPDX-License-Identifier: Unlicense

"""reconnect tests"""

import logging
import ssl
import sys

import pytest
from mocket import Mocket

import adafruit_minimqtt.adafruit_minimqtt as MQTT

if not sys.implementation.name == "circuitpython":
from typing import Optional

from circuitpython_typing.socket import (
SocketType,
SSLContextType,
)


class FakeConnectionManager:
"""
Fake ConnectionManager class
"""

def __init__(self, socket):
self._socket = socket
self.close_cnt = 0

def get_socket( # noqa: PLR0913, Too many arguments
self,
host: str,
port: int,
proto: str,
session_id: Optional[str] = None,
*,
timeout: float = 1.0,
is_ssl: bool = False,
ssl_context: Optional[SSLContextType] = None,
) -> SocketType:
"""
Return the specified socket.
"""
return self._socket

def close_socket(self, socket) -> None:
self.close_cnt += 1


def handle_subscribe(client, user_data, topic, qos):
"""
Record topics into user data.
"""
assert topic
assert user_data["topics"] is not None
assert qos == 0

user_data["topics"].append(topic)


def handle_disconnect(client, user_data, zero):
"""
Record disconnect.
"""

user_data["disconnect"] = True


# The MQTT packet contents below were captured using Mosquitto client+server.
testdata = [
(
[],
bytearray(
[
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x01,
0x00,
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x02,
0x00,
]
),
),
(
[("foo/bar", 0)],
bytearray(
[
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x01,
0x00,
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x02,
0x00,
]
),
),
(
[("foo/bar", 0), ("bah", 0)],
bytearray(
[
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x01,
0x00,
0x00,
0x20, # CONNACK
0x02,
0x00,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x02,
0x00,
0x90, # SUBACK
0x03,
0x00,
0x03,
0x00,
]
),
),
]


@pytest.mark.parametrize(
"topics,to_send",
testdata,
ids=[
"no_topic",
"single_topic",
"multi_topic",
],
)
def test_reconnect(topics, to_send) -> None:
"""
Test reconnect() handling, mainly that it performs disconnect on already connected socket.

Nothing will travel over the wire, it is all fake.
"""
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

host = "localhost"
port = 1883

user_data = {"topics": [], "disconnect": False}
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
ssl_context=ssl.create_default_context(),
connect_retries=1,
user_data=user_data,
)

mocket = Mocket(to_send)
mqtt_client._connection_manager = FakeConnectionManager(mocket)
mqtt_client.connect()

mqtt_client.logger = logger

if topics:
logger.info(f"subscribing to {topics}")
mqtt_client.subscribe(topics)

logger.info("reconnecting")
mqtt_client.on_subscribe = handle_subscribe
mqtt_client.on_disconnect = handle_disconnect
mqtt_client.reconnect()

assert user_data.get("disconnect") == True
assert mqtt_client._connection_manager.close_cnt == 1
assert set(user_data.get("topics")) == set([t[0] for t in topics])


def test_reconnect_not_connected() -> None:
"""
Test reconnect() handling not connected.
"""
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

host = "localhost"
port = 1883

user_data = {"topics": [], "disconnect": False}
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
ssl_context=ssl.create_default_context(),
connect_retries=1,
user_data=user_data,
)

mocket = Mocket(
bytearray(
[
0x20, # CONNACK
0x02,
0x00,
0x00,
]
)
)
mqtt_client._connection_manager = FakeConnectionManager(mocket)

mqtt_client.logger = logger
mqtt_client.on_disconnect = handle_disconnect
mqtt_client.reconnect()

assert user_data.get("disconnect") == False
assert mqtt_client._connection_manager.close_cnt == 0