diff --git a/setup.cfg b/setup.cfg index ed748aa..4450a45 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,11 +34,7 @@ install_requires = apscheduler cachetools -# The usage of test_requires is discouraged, see `Dependency Management` docs -tests_require = - pytest - pytest-cov - pytest-asyncio + # Require a specific Python version, e.g. Python 2.7 or >= 3.4 python_requires = >=3.8 @@ -56,6 +52,7 @@ testing = pytest pytest-cov pytest-asyncio + aiomisc-pytest [options.entry_points] # Add here console scripts like: @@ -90,6 +87,8 @@ testpaths = tests # markers = # slow: mark tests as slow (deselect with '-m "not slow"') +asyncio_default_fixture_loop_scope = function + [aliases] dists = bdist_wheel diff --git a/src/amqp_fabric/amq_broker_connector.py b/src/amqp_fabric/amq_broker_connector.py index 1920064..357a90a 100644 --- a/src/amqp_fabric/amq_broker_connector.py +++ b/src/amqp_fabric/amq_broker_connector.py @@ -57,11 +57,17 @@ def deserialize(self, data: Any) -> bytes: return super().deserialize(gzip.decompress(data)) +BROKER_RECONNECT_RETRY_DELAY = os.environ.get( + "AMQFAB_BROKER_RECONNECT_RETRY_DELAY", 5.0 +) +BROKER_HEARTBEAT = os.environ.get("AMQFAB_BROKER_HEARTBEAT", 60) MSG_TYPE_KEEP_ALIVE = "keep_alive" -MAX_DISCOVERY_CACHE_ENTRIES = os.environ.get("MAX_DISCOVERY_CACHE_ENTRIES", 100) -DISCOVERY_CACHE_TTL = os.environ.get("DISCOVERY_CACHE_TTL", 5) -DATA_EXCHANGE_NAME = os.environ.get("DATA_EXCHANGE_NAME", "data") -DISCOVERY_EXCHANGE_NAME = os.environ.get("DISCOVERY_EXCHANGE_NAME", "msc.discovery") +MAX_DISCOVERY_CACHE_ENTRIES = os.environ.get("AMQFAB_MAX_DISCOVERY_CACHE_ENTRIES", 100) +DISCOVERY_CACHE_TTL = os.environ.get("AMQFAB_DISCOVERY_CACHE_TTL", 5) +DATA_EXCHANGE_NAME = os.environ.get("AMQFAB_DATA_EXCHANGE_NAME", "data") +DISCOVERY_EXCHANGE_NAME = os.environ.get( + "AMQFAB_DISCOVERY_EXCHANGE_NAME", "msc.discovery" +) REGEX_FQN_PATTERN = r"^(?:[A-Za-z0-9-_]{1,63}\.){1,255}[A-Za-z0-9-_]{1,63}$" @@ -126,6 +132,8 @@ def __init__( self._keepalive_subscriber_service_type = None self._keepalive_subscriber_service_id = None + self._api = None + @property def domain(self): return self._service_domain @@ -146,22 +154,42 @@ def data_exchange(self): def fqn(self): return broker_fqn(self._service_domain, self._service_type, self._service_id) + async def _on_reconnect(self, connection=None): + try: + # This will create the exchange if it doesn't already exist. + channel = await self._broker_conn.channel() + + self._data_exchange = await channel.declare_exchange( + name=self._data_exchange_name, type=ExchangeType.HEADERS, durable=True + ) + self._discovery_exchange = await channel.declare_exchange( + name=self._discovery_exchange_name, + type=ExchangeType.HEADERS, + durable=True, + ) + + log.info(f"Service '{self.fqn}' connected to broker.") + + if self._api: + await self.rpc_register(self._api) + + except Exception as e: + log.error("Error reconnecting....") + log.error(e) + async def open(self, **kwargs: Any): self._broker_conn = await connect_robust( url=self._amqp_uri, client_properties={"connection_name": "rpc_srv"}, + connection_attempts=None, # None means infinite retries + retry_delay=BROKER_RECONNECT_RETRY_DELAY, # wait 5 s between attempts + heartbeat=BROKER_HEARTBEAT, # send heartbeats every minute **kwargs, ) - # This will create the exchange if it doesn't already exist. - channel = await self._broker_conn.channel() + self._broker_conn.reconnect_callbacks.add(self._on_reconnect) - self._data_exchange = await channel.declare_exchange( - name=self._data_exchange_name, type=ExchangeType.HEADERS, durable=True - ) - self._discovery_exchange = await channel.declare_exchange( - name=self._discovery_exchange_name, type=ExchangeType.HEADERS, durable=True - ) + await self._on_reconnect() await aio.sleep(0.1) # Initialize keep-alive messages @@ -183,6 +211,8 @@ async def open(self, **kwargs: Any): # Initialize keep-alive listener if self._keep_alive_listen: + channel = await self._broker_conn.channel() + self._discovery_cache = TTLCache( maxsize=MAX_DISCOVERY_CACHE_ENTRIES, ttl=self._discovery_cache_ttl ) @@ -203,11 +233,14 @@ async def close(self): self._scheduler.shutdown(wait=True) self._scheduler = None + self._api = None await self._broker_conn.close() # --- Service management routines --- async def rpc_register(self, api): + self._api = api + # Creating channel channel = await self._broker_conn.channel() await channel.set_qos(prefetch_count=1) @@ -225,9 +258,7 @@ async def rpc_register(self, api): await rpc.register(api_name, awaitify(callee), auto_delete=True) log.info( - 'RPC Server Registered on Exchange "{}"'.format( - self._rpc_server_exchange_name - ) + f'RPC Server Registered on Exchange "{self._rpc_server_exchange_name}"' ) async def rpc_proxy(self, service_domain, service_id, service_type): @@ -311,19 +342,21 @@ async def subscribe_data(self, subscriber_name, headers, callback): await queue.consume(callback) async def _on_send_keep_alive(self): - try: - headers = { - "msg_type": MSG_TYPE_KEEP_ALIVE, - "service_domain": self._service_domain, - "service_id": self._service_id, - "service_type": self._service_type, - } - - aio.create_task( - self._discovery_exchange.publish( - message=Message(body="".encode(), headers=headers), routing_key="" - ) + headers = { + "msg_type": MSG_TYPE_KEEP_ALIVE, + "service_domain": self._service_domain, + "service_id": self._service_id, + "service_type": self._service_type, + } + + task = aio.create_task( + self._discovery_exchange.publish( + message=Message(body="".encode(), headers=headers), routing_key="" ) + ) + + try: + await task # Exception is raised here except Exception as e: log.error(e) @@ -359,7 +392,8 @@ async def _on_get_keep_alive(self, message: IncomingMessage): or headers["service_id"] == self._keepalive_service_service_id ) ): - aio.create_task(self._keepalive_subscriber_callback(headers)) + task = aio.create_task(self._keepalive_subscriber_callback(headers)) + await task # Exception is raised here except Exception as e: log.error(e) diff --git a/tests/conftest.py b/tests/conftest.py index 8493fd5..c250c61 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,12 +9,33 @@ """ import os +from typing import Type -AMQP_URL = os.environ.get("AMQP_URL", "amqp://guest:guest@localhost/") +import pytest +from aiomisc_pytest import TCPProxy + +AMQP_HOST = os.environ.get("AMQP_HOST", "localhost") +AMQP_PORT = os.environ.get("AMQP_PORT", "5672") +AMQP_URL = os.environ.get("AMQP_URL", f"amqp://guest:guest@{AMQP_HOST}:{AMQP_PORT}/") SERVICE_ID = os.environ.get("SERVICE_ID", "amqp-fabric") SERVICE_TYPE = os.environ.get("SERVICE_TYPE", "no-type") SERVICE_DOMAIN = os.environ.get("SERVICE_DOMAIN", "some-domain") RPC_EXCHANGE_NAME = os.environ.get( "RPC_EXCHANGE_NAME", f"{SERVICE_DOMAIN}.api.{SERVICE_TYPE}.{SERVICE_ID}" ) -DATA_EXCHANGE_NAME = os.environ.get("DATA_EXCHANGE_NAME", f"{SERVICE_DOMAIN}.daq.data") +DATA_EXCHANGE_NAME = os.environ.get("DATA_EXCHANGE_NAME", f"{SERVICE_DOMAIN}.data") + + +@pytest.fixture +async def proxy(tcp_proxy: Type[TCPProxy]): + p = tcp_proxy( + AMQP_HOST, + AMQP_PORT, + buffered=False, + ) + + await p.start() + try: + yield p + finally: + await p.close() diff --git a/tests/test_amqp_broker_connector.py b/tests/test_amqp_broker_connector.py index 619c7dc..7c4b996 100644 --- a/tests/test_amqp_broker_connector.py +++ b/tests/test_amqp_broker_connector.py @@ -2,10 +2,12 @@ import datetime as dt import json +import aiomisc import pytest from aio_pika import IncomingMessage, connect_robust -from aio_pika.exceptions import MessageProcessError +from aio_pika.exceptions import CONNECTION_EXCEPTIONS, MessageProcessError from aio_pika.patterns.rpc import JsonRPCError +from aiomisc_pytest import TCPProxy from conftest import ( AMQP_URL, RPC_EXCHANGE_NAME, @@ -315,3 +317,68 @@ async def on_new_data(message: IncomingMessage): await srv_conn.close() await client_conn.close() + + +@aiomisc.timeout(30) +@pytest.mark.asyncio +async def test_server_reconnects(proxy: TCPProxy): + api = TestApi() + + amqp_url = f"amqp://guest:guest@{proxy.proxy_host}:{proxy.proxy_port}/" + reconnect_event = asyncio.Event() + + srv_conn = AmqBrokerConnector( + amqp_uri=amqp_url, + service_domain=SERVICE_DOMAIN, + service_type=SERVICE_TYPE, + service_id=SERVICE_ID, + keep_alive_seconds=2, + ) + await srv_conn.open() + + srv_conn._broker_conn.reconnect_callbacks.add( + lambda *_: reconnect_event.set(), + ) + + assert srv_conn.fqn == f"{SERVICE_DOMAIN}.{SERVICE_TYPE}.{SERVICE_ID}" + assert srv_conn.service_id == SERVICE_ID + assert srv_conn.service_type == SERVICE_TYPE + assert srv_conn.domain == SERVICE_DOMAIN + assert srv_conn.data_exchange == f"{SERVICE_DOMAIN}.data" + + # Init server + await srv_conn.rpc_register(api) + + # Init client + client_conn = AmqBrokerConnector( + amqp_uri=AMQP_URL, + service_domain=SERVICE_DOMAIN, + service_type="client", + service_id="client", + ) + await client_conn.open() + + rpc_proxy = await client_conn.rpc_proxy( + service_domain=SERVICE_DOMAIN, + service_id=SERVICE_ID, + service_type=SERVICE_TYPE, + ) + + assert await rpc_proxy.multiply(x=100, y=2) + assert srv_conn._scheduler + + # Disconnect existing client + await proxy.disconnect_all() + + with pytest.raises(CONNECTION_EXCEPTIONS): + await rpc_proxy.multiply(x=100, y=2) + + # Wait for reconnect + await asyncio.wait_for(reconnect_event.wait(), timeout=10) + # + # # Test RPC again + # await rpc_proxy.multiply(x=100, y=2) + # + await client_conn.close() + + await srv_conn._broker_conn.close()