Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/gateway/binance/ws/threaded_stream.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._loop: asyncio.AbstractEventLoop = get_loop() if _loop is None else _loop
self._loop: asyncio.AbstractEventLoop = None # Initialize as None, created in the run method
why? how do you test the code?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An asyncio event loop is not thread-safe. It is designed to be owned and managed by a single thread. The original code assigned the event loop in the init method, which runs in the main thread. The run() method, however, executes in a new, separate worker thread. Passing an event loop from one thread to another to be run can lead to race conditions and other unpredictable concurrency issues.

The best practice, which the new code follows, is that the thread that will run the event loop should be the one to create, manage, and eventually close it. By creating the loop inside the run() method, we ensure the new thread has its own private, isolated event loop, making the entire manager much more robust and predictable.

I added the test unit test_threaded_stream.py to test ThreadedApiManager by verifying that the thread can be started, is confirmed to be alive, and then can be stopped and cleaned up properly. Also, I simulate receiving a message from a socket and assert that the provided callback function is correctly invoked with the message content.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK,I will take a look. thanks for your PR. Best regards.

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
):
"""Initialise the BinanceSocketManager"""
super().__init__()
self._loop: asyncio.AbstractEventLoop = get_loop() if _loop is None else _loop
self._loop: asyncio.AbstractEventLoop = None # Initialize as None, created in the run method
self._client: Optional[AsyncClient] = None
self._running: bool = True
self._socket_running: Dict[str, bool] = {}
Expand Down Expand Up @@ -62,7 +62,16 @@ async def start_listener(self, socket, path: str, callback):
del self._socket_running[path]

def run(self):
self._loop.run_until_complete(self.socket_listener())
# Create a new event loop for each thread
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

try:
# Run the event loop until completion
self._loop.run_until_complete(self.socket_listener())
finally:
# Ensure the event loop is closed when the thread ends
self._loop.close()

def stop_socket(self, socket_name):
if socket_name in self._socket_running:
Expand All @@ -79,6 +88,7 @@ def stop(self):
self._running = False
if self._client and self._loop and not self._loop.is_closed():
try:
# Use run_coroutine_threadsafe to execute coroutines in the event loop
future = asyncio.run_coroutine_threadsafe(
self.stop_client(), self._loop
)
Expand Down
79 changes: 79 additions & 0 deletions src/test/test_threaded_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import asyncio
import time
import unittest
from unittest.mock import patch, MagicMock, AsyncMock

from src.gateway.binance.ws.threaded_stream import ThreadedApiManager


class TestThreadedApiManager(unittest.TestCase):
@patch("src.gateway.binance.ws.threaded_stream.AsyncClient.create", new_callable=AsyncMock)
def test_thread_lifecycle_and_loop_management(self, mock_async_client):
"""Tests the thread's startup, shutdown, and event loop cleanup."""
manager = ThreadedApiManager(api_key="test", api_secret="test")

self.assertIsNone(manager._loop)
self.assertTrue(manager._running)

manager.start()
time.sleep(0.1) # Give the thread time to start and create the loop

self.assertTrue(manager.is_alive())
self.assertIsNotNone(manager._loop)
self.assertTrue(manager._loop.is_running())

manager.stop()
manager.join(timeout=5) # Wait for the thread to terminate

self.assertFalse(manager.is_alive())
self.assertFalse(manager._running)
# After the thread stops, the loop should be closed.
self.assertTrue(manager._loop.is_closed())
mock_async_client.assert_called_once()

@patch("src.gateway.binance.ws.threaded_stream.AsyncClient.create", new_callable=AsyncMock)
def test_callback_is_called_on_message(self, mock_async_client):
"""Tests that the callback is invoked when a message is received."""
manager = ThreadedApiManager(api_key="test", api_secret="test")

# Mock the async context manager for the socket and its recv method
mock_socket = AsyncMock()
mock_recv = mock_socket.__aenter__.return_value.recv

# Setup the mock to simulate receiving one message
test_msg = {"data": "test_message"}
mock_recv.return_value = test_msg

mock_callback = MagicMock()
path = "test_path"

# This wrapper callback will call the mock and then stop the listener loop
def callback_with_side_effect(msg):
mock_callback(msg)
manager.stop_socket(path)

manager._socket_running[path] = True

async def test_runner():
# This coroutine will be run in the manager's event loop
await manager.start_listener(mock_socket, path, callback_with_side_effect)

manager.start()
# Wait until the loop is actually running to avoid a race condition
while not (manager._loop and manager._loop.is_running()):
time.sleep(0.01)

# Run the test coroutine in the thread's event loop
future = asyncio.run_coroutine_threadsafe(test_runner(), manager._loop)
future.result(timeout=5) # Wait for completion

manager.stop()
manager.join()

# Assertions
mock_callback.assert_called_once_with(test_msg)
self.assertFalse(path in manager._socket_running)


if __name__ == "__main__":
unittest.main()
Loading