From 6c12e2006ab2fccad8f69142f4838bf54f048b33 Mon Sep 17 00:00:00 2001 From: Patrick Gingras <775.pg.12@gmail.com> Date: Mon, 7 Dec 2020 21:10:33 -0500 Subject: [PATCH] initial commit --- .flake8 | 2 + .github/CODEOWNERS | 1 + .github/workflows/build_deploy.yml | 69 +++++ .github/workflows/ci.yml | 46 ++++ .gitignore | 5 + LICENSE | 29 ++ README.md | 121 +++++++++ asyncio_connection_pool/__init__.py | 165 +++++++++++ asyncio_connection_pool/contrib/__init__.py | 0 asyncio_connection_pool/contrib/aioredis.py | 19 ++ asyncio_connection_pool/contrib/datadog.py | 135 +++++++++ mypy.ini | 18 ++ riotfile.py | 45 +++ setup.py | 35 +++ test/test_pool.py | 287 ++++++++++++++++++++ 15 files changed, 977 insertions(+) create mode 100644 .flake8 create mode 100644 .github/CODEOWNERS create mode 100644 .github/workflows/build_deploy.yml create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 asyncio_connection_pool/__init__.py create mode 100644 asyncio_connection_pool/contrib/__init__.py create mode 100644 asyncio_connection_pool/contrib/aioredis.py create mode 100644 asyncio_connection_pool/contrib/datadog.py create mode 100644 mypy.ini create mode 100644 riotfile.py create mode 100644 setup.py create mode 100644 test/test_pool.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..f4546ad --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max_line_length = 88 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..4a4e855 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @p7g diff --git a/.github/workflows/build_deploy.yml b/.github/workflows/build_deploy.yml new file mode 100644 index 0000000..e2f9a5b --- /dev/null +++ b/.github/workflows/build_deploy.yml @@ -0,0 +1,69 @@ +name: Build + +on: + pull_request: + release: + types: + - published + +jobs: + build_wheel: + name: Build wheel + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + # Include all history and tags + with: + fetch-depth: 0 + + - uses: actions/setup-python@v2 + name: Install Python + with: + python-version: '3.8' + + - name: Build wheel + run: | + python -m pip install wheel + python -m pip wheel -w dist . + + - uses: actions/upload-artifact@v2 + with: + path: dist/*.whl + + build_sdist: + name: Build source distribution + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + # Include all history and tags + with: + fetch-depth: 0 + + - uses: actions/setup-python@v2 + name: Install Python + with: + python-version: '3.8' + + - name: Build sdist + run: | + python setup.py sdist + + - uses: actions/upload-artifact@v2 + with: + path: dist/*.tar.gz + + upload_pypi: + needs: [build_wheel, build_sdist] + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + steps: + - uses: actions/download-artifact@v2 + with: + name: artifact + path: dist + + - uses: pypa/gh-action-pypi-publish@master + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..364cb07 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,46 @@ +name: CI +on: [push, pull_request] +jobs: + black: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.9' + - run: pip install riot==0.4.0 + - run: riot -v run -s black -- --check . 
+ mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.9' + - run: pip install riot==0.4.0 + - run: riot -v run mypy + flake8: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.9' + - run: pip install riot==0.4.0 + - run: riot -v run flake8 + test: + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + python-version: [3.8, 3.9] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: install riot + run: pip install riot==0.4.0 + - name: run tests + run: riot -v run --python=${{ matrix.python-version }} test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..08974ef --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +.eggs/ +*.egg-info/ +.mypy_cache/ +.riot diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d071125 --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2019, Fellow Insights Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a61ab33 --- /dev/null +++ b/README.md @@ -0,0 +1,121 @@ +# `asyncio-connection-pool` + +This is a generic, high-throughput, optionally-burstable pool for asyncio. + +Some cool features: + +- No locking aside from the GIL; no `asyncio.Lock` or `asyncio.Condition` needs + to be taken in order to get a connection. +- Available connections are retrieved without yielding to the event loop. +- When `burst_limit` is specified, `max_size` acts as a "soft" limit; the pool + can go beyond this limit to handle increased load, and shrinks back down + after. +- The contents of the pool can be anything; just implement a + `ConnectionStrategy`. + + +## Why? 
+
+We were using a different pool for handling our Redis connections, and noticed
+that, under heavy load, we would spend a lot of time waiting for the lock, even
+when there were available connections in the pool.
+
+We also thought it would be nice if we didn't need to keep many connections
+open when they weren't needed, but still have the ability to open more when
+they are required.
+
+
+## API
+
+
+### `asyncio_connection_pool.ConnectionPool`
+
+This is the implementation of the pool. It is generic over a type of
+connection, and all implementation-specific logic is contained within a
+[`ConnectionStrategy`](#connectionstrategy).
+
+A pool is created as follows:
+
+```python
+from asyncio_connection_pool import ConnectionPool
+
+pool = ConnectionPool(strategy=my_strategy, max_size=15)
+```
+
+The constructor can optionally be passed an integer as `burst_limit`. This
+allows the pool to open more connections than `max_size` temporarily.
+
+
+#### `@asynccontextmanager async def get_connection(self) -> AsyncIterator[Conn]`
+
+This method is the only way to get a connection from the pool. It is expected
+to be used as follows:
+
+```python
+pool = ConnectionPool(...)
+
+async with pool.get_connection() as conn:
+    # Use the connection
+    pass
+```
+
+When the `async with` block is entered, a connection is retrieved. If a
+connection needs to be opened or if the pool is at capacity and no connections
+are available, the caller will yield to the event loop.
+
+When the block is exited, the connection will be returned to the pool.
+
+
+### `asyncio_connection_pool.ConnectionStrategy`
+
+This is an abstract class that defines the interface of the object passed as
+`strategy`. A subclass _must_ implement the following methods:
+
+
+#### `async def make_connection(self) -> Awaitable[Conn]`
+
+This method is called to create a new connection to the resource. This happens
+when a connection is requested and all connections are in use, as long as the
+pool is not at capacity.
+
+The result of a call to this method is what will be provided to a consumer of
+the pool, and in most cases will be stored in the pool to be re-used later.
+
+If this method raises an exception, it will bubble up to the frame where
+`ConnectionPool.get_connection()` was called.
+
+
+#### `def connection_is_closed(self, conn: Conn) -> bool`
+
+This method is called to check if a connection is no longer usable. When the
+pool is retrieving a connection to give to a client, this method is called to
+make sure it is valid.
+
+The return value should be `True` if the connection is _not_ valid.
+
+If this method raises an exception, it is assumed that the connection is
+invalid. The passed-in connection is dropped, and the exception propagates to
+the caller of `ConnectionPool.get_connection()`, including `BaseException`s
+like `asyncio.CancelledError`. It is the responsibility of the
+`ConnectionStrategy` implementation to avoid leaking a connection in this
+case.
+
+
+#### `def close_connection(self, conn: Conn) -> None`
+
+This method is called to close a connection. This occurs when the pool has
+exceeded `max_size` (i.e. it is bursting) and a connection is returned that is
+no longer needed (i.e. there are no more consumers waiting for a connection).
+
+Note that this method is synchronous; if closing a connection is an
+asynchronous operation, `asyncio.create_task` can be used.
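+
+For example, a strategy that pools raw asyncio TCP streams might schedule the
+asynchronous part of the close instead of awaiting it. This is an illustrative
+sketch only (`TcpStreamStrategy` is made up for this example, not part of the
+library):
+
+```python
+import asyncio
+
+from asyncio_connection_pool import ConnectionStrategy
+
+
+class TcpStreamStrategy(ConnectionStrategy):
+    def __init__(self, host, port):
+        self.host = host
+        self.port = port
+
+    async def make_connection(self):
+        reader, writer = await asyncio.open_connection(self.host, self.port)
+        return reader, writer
+
+    def connection_is_closed(self, conn):
+        _reader, writer = conn
+        return writer.is_closing()
+
+    def close_connection(self, conn):
+        _reader, writer = conn
+        writer.close()
+        # wait_closed() is a coroutine, so schedule it rather than awaiting it
+        # in this synchronous method.
+        asyncio.create_task(writer.wait_closed())
+```
+
+The bundled `asyncio_connection_pool.contrib.aioredis.RedisConnectionStrategy`
+doesn't need this, since it closes connections with a plain, synchronous
+`conn.close()` call.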
+
+If `close_connection` raises an exception, the connection is dropped and the
+exception bubbles to the caller of `ConnectionPool.get_connection().__aexit__`
+(usually an `async with` block).
+
+
+## How is this safe without locks?
+
+I encourage you to read the [source](https://github.com/fellowinsights/asyncio-connection-pool/blob/master/asyncio_connection_pool/__init__.py)
+to find out (it is quite well-commented). If you notice any faults in the
+logic, please feel free to file an issue.
diff --git a/asyncio_connection_pool/__init__.py b/asyncio_connection_pool/__init__.py
new file mode 100644
index 0000000..06450cd
--- /dev/null
+++ b/asyncio_connection_pool/__init__.py
@@ -0,0 +1,165 @@
+from abc import ABC, abstractmethod
+import asyncio
+from contextlib import asynccontextmanager
+from typing import AsyncIterator, Awaitable, Generic, Optional, TypeVar
+
+__all__ = "ConnectionPool", "ConnectionStrategy"
+Conn = TypeVar("Conn")
+
+
+class ConnectionStrategy(ABC, Generic[Conn]):
+    @abstractmethod
+    async def make_connection(self) -> Awaitable[Conn]:
+        ...
+
+    @abstractmethod
+    def connection_is_closed(self, conn: Conn) -> bool:
+        ...
+
+    @abstractmethod
+    def close_connection(self, conn: Conn) -> None:
+        ...
+
+
+class ConnectionPool(Generic[Conn]):
+    """A high-throughput, optionally-burstable pool free of explicit locking.
+
+    NOTE: Not threadsafe. Do not share across threads.
+
+    This pool offers high throughput by avoiding the need for an explicit
+    lock to retrieve a connection. This is possible by taking advantage of
+    Python's global interpreter lock (GIL).
+
+    If the optional `burst_limit` argument is supplied, the `max_size` argument
+    will act as a "soft" maximum. When there is demand, more connections will
+    be opened to satisfy it, up to `burst_limit`. When these connections are no
+    longer needed, they will be closed. This way we can avoid holding many open
+    connections for extended times.
+
+    Since we make use of the GIL, this pool should not be shared across
+    threads. This is unsafe because some C extensions release the GIL when
+    waiting on IO, for example. In that case, a different thread can actually
+    execute concurrently, which breaks the assumptions upon which this pool is
+    based.
+
+    This pool is generic over the type of connection it holds, which can be
+    anything. Any implementation-dependent logic belongs in the
+    ConnectionStrategy, which should be passed to the pool's constructor via
+    the `strategy` parameter.
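+
+    A minimal usage sketch (`IntStrategy` is a made-up stand-in for a real
+    `ConnectionStrategy` implementation):
+
+        class IntStrategy(ConnectionStrategy):
+            async def make_connection(self):
+                return 42
+
+            def connection_is_closed(self, conn):
+                return False
+
+            def close_connection(self, conn):
+                pass
+
+        async def use_pool():
+            pool = ConnectionPool(
+                strategy=IntStrategy(), max_size=10, burst_limit=20
+            )
+            async with pool.get_connection() as conn:
+                ...  # use the connection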
+ """ + + def __init__( + self, + *, + strategy: ConnectionStrategy[Conn], + max_size: int, + burst_limit: Optional[int] = None + ): + self._loop = asyncio.get_event_loop() + self.strategy = strategy + self.max_size = max_size + self.burst_limit = burst_limit + if burst_limit is not None and burst_limit < max_size: + raise ValueError("burst_limit must be greater than or equal to max_size") + self.in_use = 0 + self.currently_allocating = 0 + self.available: "asyncio.Queue[Conn]" = asyncio.Queue(maxsize=self.max_size) + + @property + def _total(self) -> int: + return self.in_use + self.currently_allocating + self.available.qsize() + + @property + def _waiters(self) -> int: + waiters = self.available._getters # type: ignore + return sum(not (w.done() or w.cancelled()) for w in waiters) + + async def _connection_maker(self): + try: + conn = await self.strategy.make_connection() + finally: + self.currently_allocating -= 1 + self.in_use += 1 + return conn + + async def _connection_waiter(self): + conn = await self.available.get() + self.in_use += 1 + return conn + + def _get_conn(self) -> "Awaitable[Conn]": + # This function is how we avoid explicitly locking. Since it is + # synchronous, we do all the "book-keeping" required to get a + # connection synchronously (i.e. implicitly holding the GIL), and + # return a Future or Task which can be awaited after this function + # returns. + # + # The most important thing here is that we have the GIL from when we + # measure values like `self._total` or `self.available.empty()` until + # we change values that affect those measurements. In other words, + # taking a connection must be an atomic operation. + if not self.available.empty(): + # Reserve a connection and wrap in a Future to make it awaitable. + # Incidentally, awaiting a done Future doesn't involve yielding to + # the event loop; it's more like getting the next value from a + # generator. + fut: "asyncio.Future[Conn]" = self._loop.create_future() + fut.set_result(self.available.get_nowait()) + self.in_use += 1 + return fut + elif self._total < self.max_size or ( + self.burst_limit is not None and self._total < self.burst_limit + ): + # Reserve a space for a connection and asynchronously make it. + # Returns a Task that resolves to the new connection, which can be + # awaited. + # + # If there are a lot of threads waiting for a connection, to avoid + # having all of them time out and be cancelled, we'll burst to + # higher max_size. + self.currently_allocating += 1 + return self._loop.create_task(self._connection_maker()) + else: + # Return a Task that waits for the next task to appear in the queue. + return self._loop.create_task(self._connection_waiter()) + + @asynccontextmanager + async def get_connection(self) -> AsyncIterator[Conn]: # type: ignore + # _get_conn atomically does any book-keeping and returns an awaitable + # that resolves to a connection. + conn = await self._get_conn() + # Repeat until the connection we get is still open. + while True: + try: + if not self.strategy.connection_is_closed(conn): + break + except BaseException: + self.in_use -= 1 + raise + self.in_use -= 1 # Incremented in _get_conn + conn = await self._get_conn() + + try: + # Evaluate the body of the `async with` block. + yield conn + finally: + # Return the connection to the pool. + self.in_use -= 1 + assert self.in_use >= 0, "More connections returned than given" + + try: + # Check if we are currently over-committed (i.e. 
bursting) + if self._total >= self.max_size and self._waiters == 0: + # We had created extra connections to handle burst load, + # but there are no more waiters, so we don't need this + # connection anymore. + self.strategy.close_connection(conn) + else: + self.available.put_nowait(conn) + except asyncio.QueueFull: + # We don't actually check if the queue has room before trying + # to put the connection into it. It's unclear whether we could + # have a full queue and still have waiters, but we should + # handle this case to be safe (otherwise we would leak + # connections). + self.strategy.close_connection(conn) diff --git a/asyncio_connection_pool/contrib/__init__.py b/asyncio_connection_pool/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/asyncio_connection_pool/contrib/aioredis.py b/asyncio_connection_pool/contrib/aioredis.py new file mode 100644 index 0000000..e6cb2e7 --- /dev/null +++ b/asyncio_connection_pool/contrib/aioredis.py @@ -0,0 +1,19 @@ +import aioredis +from functools import partial +from asyncio_connection_pool import ConnectionStrategy + +__all__ = ("RedisConnectionStrategy",) + + +class RedisConnectionStrategy(ConnectionStrategy[aioredis.Redis]): # type: ignore + def __init__(self, *args, **kwargs): + self._create_redis = partial(aioredis.create_redis, *args, **kwargs) + + async def make_connection(self): + return await self._create_redis() + + def connection_is_closed(self, conn): + return conn.closed + + def close_connection(self, conn): + conn.close() diff --git a/asyncio_connection_pool/contrib/datadog.py b/asyncio_connection_pool/contrib/datadog.py new file mode 100644 index 0000000..f948a6c --- /dev/null +++ b/asyncio_connection_pool/contrib/datadog.py @@ -0,0 +1,135 @@ +from contextlib import asynccontextmanager, AsyncExitStack +from datadog import statsd +from ddtrace import tracer +from typing import AsyncIterator, TypeVar +from asyncio_connection_pool import ConnectionPool as _ConnectionPool + +__all__ = ("ConnectionPool",) +Conn = TypeVar("Conn") + + +class ConnectionPool(_ConnectionPool[Conn]): + def __init__(self, service_name, *args, extra_tags=None, **kwargs): + super().__init__(*args, **kwargs) + self._connections_acquiring = 0 + self._service_name = service_name + self._is_bursting = False + self._reported_hitting_burst_limit = False + self._extra_tags = extra_tags or [] + self._loop.call_soon(self._periodically_send_metrics) + + def _periodically_send_metrics(self): + try: + self._record_pressure() + finally: + self._loop.call_later(60, self._periodically_send_metrics) + + def _record_pressure(self): + statsd.gauge( + f"{self._service_name}.pool.total_connections", + self._total, + tags=self._extra_tags, + ) + statsd.gauge( + f"{self._service_name}.pool.available_connections", + self.available.qsize(), + tags=self._extra_tags, + ) + statsd.gauge( + f"{self._service_name}.pool.waiting", self._waiters, tags=self._extra_tags + ) + statsd.gauge( + f"{self._service_name}.pool.connections_used", + self.in_use, + tags=self._extra_tags, + ) + self._record_connection_acquiring() + if self._total > self.max_size: + if not self._is_bursting: + self._is_bursting = True + statsd.event( + f"{self._service_name} pool using burst capacity", + f"Pool max size of {self.max_size} will be exceeded temporarily, up to {self.burst_limit}", # noqa E501 + alert_type="warning", + tags=self._extra_tags, + ) + elif self._is_bursting: + self._is_bursting = False + self._reported_hitting_burst_limit = False + statsd.event( + f"{self._service_name} pool 
no longer bursting", + f"Number of connections has dropped below {self.max_size}", + alert_type="success", + tags=self._extra_tags, + ) + if self._total == self.burst_limit: + self._reported_hitting_burst_limit = True + statsd.event( + f"{self._service_name} pool reached burst limit", + "There are not enough redis connections to satisfy all users", + alert_type="error", + tags=self._extra_tags, + ) + + def _record_connection_acquiring(self, value=0): + self._connections_acquiring += value + + statsd.gauge( + f"{self._service_name}.pool.connections_acquiring", + self._connections_acquiring, + tags=self._extra_tags, + ) + + def _connection_maker(self): + statsd.increment( + f"{self._service_name}.pool.getting_connection", + tags=self._extra_tags + ["method:new"], + ) + + async def connection_maker(self): + with tracer.trace( + f"{self._service_name}.pool._create_new_connection", + service=self._service_name, + ): + return await super()._connection_maker() + + return connection_maker(self) + + def _connection_waiter(self): + statsd.increment( + f"{self._service_name}.pool.getting_connection", + tags=self._extra_tags + ["method:wait"], + ) + + async def connection_waiter(self): + with tracer.trace( + f"{self._service_name}.pool._wait_for_connection", + service=self._service_name, + ): + return await super()._connection_waiter() + + return connection_waiter(self) + + def _get_conn(self): + if not self.available.empty(): + statsd.increment( + f"{self._service_name}.pool.getting_connection", + tags=self._extra_tags + ["method:available"], + ) + return super()._get_conn() + + @asynccontextmanager + async def get_connection(self) -> AsyncIterator[Conn]: # type: ignore + async with AsyncExitStack() as stack: + self._record_connection_acquiring(1) + try: + with tracer.trace( + f"{self._service_name}.pool.acquire_connection", + service=self._service_name, + ): + conn = await stack.enter_async_context(super().get_connection()) + finally: + self._record_connection_acquiring(-1) + self._record_pressure() + yield conn + self._record_pressure() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..256909f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,18 @@ +[mypy] +ignore_missing_imports = True +check_untyped_defs = True +disallow_any_unimported = True +disallow_any_decorated = True +disallow_any_generics = True +disallow_subclassing_any = True +disallow_incomplete_defs = False +disallow_untyped_decorators = True +no_implicit_optional = True +strict_optional = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_return_any = True +warn_unreachable = True +implicit_reexport = False +strict_equality = True +pretty = True diff --git a/riotfile.py b/riotfile.py new file mode 100644 index 0000000..44d053d --- /dev/null +++ b/riotfile.py @@ -0,0 +1,45 @@ +from riot import Venv, latest + +venv = Venv( + pys=3, + venvs=[ + Venv( + pys=[3.8, 3.9], + name="test", + command="pytest {cmdargs}", + pkgs={ + "pytest": "==6.1.2", + "pytest-asyncio": "==0.14.0", + # extras_require + "ddtrace": latest, + "datadog": latest, + "aioredis": latest, + }, + ), + Venv( + name="mypy", + command="mypy asyncio_connection_pool", + pkgs={ + "mypy": "==0.790", + }, + ), + Venv( + pkgs={"black": "==20.8b1"}, + venvs=[ + Venv( + name="fmt", + command=r"black --exclude '/\.riot/' .", + ), + Venv( + name="black", + command=r"black --exclude '/\.riot/' {cmdargs}", + ), + ], + ), + Venv( + name="flake8", + pkgs={"flake8": "==3.8.4"}, + command="flake8 test asyncio_connection_pool", + ), + ], +) diff --git a/setup.py 
b/setup.py
new file mode 100644
index 0000000..fb92d9e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,35 @@
+from setuptools import setup, find_packages
+
+with open("README.md", "r") as f:
+    long_description = f.read()
+
+setup(
+    name="asyncio-connection-pool",
+    description="A high-throughput, optionally-burstable pool free of explicit locking",
+    url="https://github.com/fellowinsights/asyncio-connection-pool",
+    author="Patrick Gingras <775.pg.12@gmail.com>",
+    author_email="775.pg.12@gmail.com",
+    classifiers=[
+        "Programming Language :: Python",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "License :: OSI Approved :: BSD License",
+    ],
+    keywords="asyncio",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    packages=find_packages(
+        include=["asyncio_connection_pool", "asyncio_connection_pool.*"]
+    ),
+    package_data={"asyncio_connection_pool": ["py.typed"]},
+    python_requires=">=3.8",
+    install_requires=[],
+    tests_require=["riot"],
+    extras_require={
+        "datadog": ["ddtrace", "datadog"],
+        "aioredis": ["aioredis"],
+    },
+    setup_requires=["setuptools_scm"],
+    use_scm_version=True,
+    zip_safe=False,  # for mypy support
+)
diff --git a/test/test_pool.py b/test/test_pool.py
new file mode 100644
index 0000000..0428483
--- /dev/null
+++ b/test/test_pool.py
@@ -0,0 +1,287 @@
+import asyncio
+import pytest
+from asyncio_connection_pool import ConnectionPool, ConnectionStrategy
+from asyncio_connection_pool.contrib.datadog import (
+    ConnectionPool as TracingConnectionPool,
+)
+from contextlib import asynccontextmanager
+from functools import partial
+
+
+@pytest.fixture(
+    params=[ConnectionPool, partial(TracingConnectionPool, service_name="test")]
+)
+def pool_cls(request):
+    return request.param
+
+
+class RandomIntStrategy(ConnectionStrategy):
+    async def make_connection(self):
+        import random
+
+        return random.randint(0, 10000)
+
+    def connection_is_closed(self, conn):
+        return False
+
+    def close_connection(self, conn):
+        pass
+
+
+def test_valid_burst_limit(pool_cls):
+    """Test that invalid burst_limit values cause errors (only at construction time)"""
+    strategy = RandomIntStrategy()
+    pool_cls(strategy=strategy, max_size=100, burst_limit=None)
+    pool_cls(strategy=strategy, max_size=100, burst_limit=100)
+    pool_cls(strategy=strategy, max_size=100, burst_limit=101)
+    with pytest.raises(ValueError):
+        pool_cls(strategy=strategy, max_size=100, burst_limit=99)
+
+
+class Counter:
+    def __init__(self, goal):
+        self.goal = goal
+        self.n = 0
+        self.ev = asyncio.Event()
+        # Prevent waiting more than once on the same counter
+        self._waiter = self.ev.wait()
+
+    def wait(self):
+        return self._waiter
+
+    @asynccontextmanager
+    async def inc(self):
+        self.n += 1
+        if self.n == self.goal:
+            self.ev.set()
+        try:
+            yield self.n
+        finally:
+            self.n -= 1
+
+
+@pytest.mark.asyncio
+async def test_concurrent_get_connection(pool_cls):
+    """Test handling several connection requests in a short time.
(Not truly + concurrent because of the GIL) + """ + + pool = pool_cls(strategy=RandomIntStrategy(), max_size=20) + nworkers = 10 + counter = Counter(nworkers) + stop = asyncio.Event() + + async def connection_holder(): + async with pool.get_connection(): + async with counter.inc(): + await stop.wait() + + coros = [asyncio.create_task(connection_holder()) for _ in range(nworkers)] + await counter.wait() + + assert pool.in_use == nworkers, f"{nworkers} connections should be in use" + assert pool.available.empty(), "There should not be any extra connections" + + stop.set() + await asyncio.gather(*coros) + + assert pool.in_use == 0 + assert ( + pool.available.qsize() == nworkers + ), f"{nworkers} connections should be allocated" + + +@pytest.mark.asyncio +async def test_currently_allocating(pool_cls): + """Test that currently_allocating is accurate.""" + + ev = asyncio.Event() + + class WaitStrategy(ConnectionStrategy): + async def make_connection(self): + await ev.wait() + + def connection_is_closed(self, conn): + return False + + def close_connection(self, conn): + pass + + nworkers = 10 + pool = pool_cls(strategy=WaitStrategy(), max_size=50) + counter = Counter(nworkers) + counter2 = Counter(nworkers) + ev2 = asyncio.Event() + + async def worker(): + async with counter.inc(): + async with pool.get_connection(): + async with counter2.inc(): + await ev2.wait() + + coros = [asyncio.create_task(worker()) for _ in range(nworkers)] + await counter.wait() + await asyncio.sleep(0) + + assert ( + pool.currently_allocating == nworkers + ), f"{nworkers} workers are waiting for a connection" + ev.set() # allow the workers to get their connections + await counter2.wait() + assert ( + pool.currently_allocating == 0 and pool.in_use == nworkers + ), "all workers should have their connections now" + ev2.set() + await asyncio.gather(*coros) + assert ( + pool.in_use == 0 and pool.available.qsize() == nworkers + ), "all workers should have returned their connections" + + +@pytest.mark.asyncio +async def test_burst(pool_cls): + """Test that bursting works when enabled and doesn't when not.""" + + did_call_close_connection = asyncio.Event() + + class Strategy(RandomIntStrategy): + def close_connection(self, conn): + did_call_close_connection.set() + return super().close_connection(conn) + + # Burst disabled initially + pool = pool_cls(strategy=Strategy(), max_size=5) + + async def worker(counter, ev): + async with pool.get_connection(): + async with counter.inc(): + await ev.wait() # hold the connection until we say so + + # Use up the normal max_size of the pool + main_event = asyncio.Event() + counter = Counter(pool.max_size) + coros = [ + asyncio.create_task(worker(counter, main_event)) for _ in range(pool.max_size) + ] + await counter.wait() + + with pytest.raises(asyncio.TimeoutError): + # Burst is disabled, can't get a connection without waiting (this is a + # deadlock) + await asyncio.wait_for(worker(counter, main_event), timeout=0.25) + + # Add a burst size to the pool + pool.burst_limit = pool.max_size + 1 + counter = Counter(goal=1) + burst_event = asyncio.Event() + burst_worker = asyncio.create_task(worker(counter, burst_event)) + await counter.wait() + assert pool._total == pool.max_size + 1 + + with pytest.raises(asyncio.TimeoutError): + # We're at burst_limit, can't get a connection without waiting + await asyncio.wait_for(worker(counter, burst_event), timeout=0.25) + + async def waiting_worker(): + async with pool.get_connection(): + pass + + coro = asyncio.create_task(waiting_worker()) # Join the 
queue + await asyncio.sleep(0.1) # Give it some time to start waiting + assert pool._waiters == 1, "Worker should be waiting, we're at burst_limit already" + burst_event.set() # Allow worker holding burst connection to finish + await burst_worker # Wait for it to release the connection + assert ( + not did_call_close_connection.is_set() + ), "Did not churn the burst connection while there was a waiter" + await coro # Should be able to take that burst connection we created + assert ( + did_call_close_connection.is_set() + ), "No more waiters, burst connection should be closed" + assert ( + pool._total == pool.max_size + ), "Pool should return to max size after burst capacity is not needed" + main_event.set() # Allow the initial workers to exit + await asyncio.gather(*coros) # Wait for initial workers to exit + assert ( + pool.available.qsize() == pool.max_size + ), "Workers should return their connections to the pool" + + +@pytest.mark.asyncio +async def test_stale_connections(pool_cls): + """Test that the pool doesn't hand out closed connections.""" + + stale_connections = {1, 2, 3, 4} + + class Strategy(ConnectionStrategy): + def __init__(self): + from itertools import count + + self.it = iter(count()) + + async def make_connection(self): + return next(self.it) + + def connection_is_closed(self, conn): + return conn in stale_connections + + def close_connection(self, conn): + stale_connections.add(conn) + + pool = pool_cls(strategy=Strategy(), max_size=10) + + async def worker(): + async with pool.get_connection() as c: + return c + + conns = await asyncio.gather(*[worker() for _ in range(10)]) + assert not stale_connections & set(conns) + + pool = pool_cls(strategy=Strategy(), max_size=1) + + async with pool.get_connection() as conn: + now_stale = conn + stale_connections.add(conn) + + async with pool.get_connection() as conn: + assert ( + conn != now_stale + ), "Make sure connections closed by consumers are not given back out" + + +@pytest.mark.asyncio +async def test_handling_cancellederror(): + making_connection = asyncio.Event() + + class Strategy(ConnectionStrategy): + async def make_connection(self): + making_connection.set() + await asyncio.Event().wait() # wait forever + return 1 + + def connection_is_closed(self, conn): + return False + + def close_connection(self, conn): + pass + + pool: TracingConnectionPool[int] = TracingConnectionPool( + strategy=Strategy(), max_size=3, service_name="test" + ) + cancelled = asyncio.Event() + + async def worker(): + try: + async with pool.get_connection(): + pass + finally: + cancelled.set() + + t = asyncio.create_task(worker()) + await making_connection.wait() + + assert pool._connections_acquiring == 1 + t.cancel() + await cancelled.wait() + assert pool._connections_acquiring == 0