Merge pull request #4 from fellowapp/async-close-connection
Async ConnectionStrategy.close_connection
p7g authored Dec 20, 2021
2 parents 875058b + 5e513cf commit a99092e
Showing 4 changed files with 63 additions and 47 deletions.
21 changes: 11 additions & 10 deletions README.md
@@ -11,15 +11,19 @@ This is a generic, high-throughput, optionally-burstable pool for asyncio.

Some cool features:

- - No locking aside from the GIL; no `asyncio.Lock` or `asyncio.Condition` needs
-   to be taken in order to get a connection.
+ - No locking[^1]; no `asyncio.Lock` or `asyncio.Condition` needs to be taken in
+   order to get a connection.
- Available connections are retrieved without yielding to the event loop.
- When `burst_limit` is specified, `max_size` acts as a "soft" limit; the pool
can go beyond this limit to handle increased load, and shrinks back down
after.
- The contents of the pool can be anything; just implement a
`ConnectionStrategy`.

+ [^1]: Theoretically, there is an implicit "lock" that is held while an asyncio
+       task is executing. No other task can execute until the current task
+       yields (since it's cooperative multitasking), so any operations during
+       that time are atomic.

## Why?

@@ -107,24 +111,21 @@ exception is suppressed unless it is not a `BaseException`, like
implementation to avoid leaking a connection in this case.


- #### `def close_connection(self, conn: Conn)`
+ #### `async def close_connection(self, conn: Conn)`

This method is called to close a connection. This occurs when the pool has
exceeded `max_size` (i.e. it is bursting) and a connection is returned that is
no longer needed (i.e. there are no more consumers waiting for a connection).

- Note that this method is synchronous; if closing a connection is an
- asynchronous operation, `asyncio.create_task` can be used.

- If this method raises an exception, the connection is dropped and the exception
- bubbles to the caller of `ConnectionPool.get_connection().__aexit__` (usually
- an `async with` block).
+ If this method raises an exception, the connection is assumed to be closed and
+ the exception bubbles to the caller of `ConnectionPool.get_connection().__aexit__`
+ (usually an `async with` block).


## Integrations with 3rd-party libraries

This package includes support for [`ddtrace`][ddtrace]/[`datadog`][datadog] and
- for [`aioredis`][aioredis].
+ for [`aioredis`][aioredis] (<2.0.0).

[ddtrace]: https://github.com/datadog/dd-trace-py
[datadog]: https://github.com/datadog/datadogpy
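
As a rough sketch of the updated `close_connection` contract documented in the README above: a strategy can now simply `await` its cleanup. The `FakeClient` type below is hypothetical and stands in for a real network client; only `ConnectionPool`, `ConnectionStrategy`, `max_size`, `burst_limit`, and `get_connection` come from this package.

```python
import asyncio

from asyncio_connection_pool import ConnectionPool, ConnectionStrategy


class FakeClient:
    """Hypothetical client used only to illustrate the interface."""

    def __init__(self) -> None:
        self.closed = False

    async def aclose(self) -> None:
        await asyncio.sleep(0)  # pretend to flush and close a socket
        self.closed = True


class FakeStrategy(ConnectionStrategy[FakeClient]):
    async def make_connection(self) -> FakeClient:
        return FakeClient()

    def connection_is_closed(self, conn: FakeClient) -> bool:
        return conn.closed

    async def close_connection(self, conn: FakeClient) -> None:
        # The pool awaits this directly; no asyncio.create_task workaround needed.
        await conn.aclose()


async def main() -> None:
    pool = ConnectionPool(strategy=FakeStrategy(), max_size=2, burst_limit=4)
    async with pool.get_connection() as conn:
        assert not conn.closed


asyncio.run(main())
```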
74 changes: 45 additions & 29 deletions asyncio_connection_pool/__init__.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
import asyncio
+ import inspect
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Generic, Optional, TypeVar

@@ -17,33 +18,39 @@ def connection_is_closed(self, conn: Conn) -> bool:
...

@abstractmethod
- def close_connection(self, conn: Conn) -> None:
+ async def close_connection(self, conn: Conn) -> None:
...


+ async def _close_connection_compat(
+ strategy: ConnectionStrategy[Conn], conn: Conn
+ ) -> None:
+ result = strategy.close_connection(conn)
+ if inspect.isawaitable(result):
+ await result


class ConnectionPool(Generic[Conn]):
"""A high-throughput, optionally-burstable pool free of explicit locking.
NOTE: Not threadsafe. Do not share across threads.
- This threadpool offers high throughput by avoiding the need for an explicit
- lock to retrieve a connection. This is possible by taking advantage of
- Pythons global interpreter lock (GIL).
+ This connection pool offers high throughput by avoiding the need for an
+ explicit lock to retrieve a connection. This is possible by taking
+ advantage of cooperative multitasking with asyncio.
If the optional `burst_limit` argument is supplied, the `max_size` argument
will act as a "soft" maximum. When there is demand, more connections will
be opened to satisfy it, up to `burst_limit`. When these connections are no
longer needed, they will be closed. This way we can avoid holding many open
connections for extended times.
- Since we make use of the GIL, this pool should not be shared across
- threads. This is unsafe because some C extensions release the GIL when
- waiting on IO, for example. In that case, a different thread can actually
- execute concurrently, which breaks the assumptions upon which this pool is
- based.
+ This implementation assumes that all operations that do not await are
+ atomic. Since CPython can switch thread contexts between each evaluated op
+ code, it is not safe to share an instance of this pool between threads.
This pool is generic over the type of connection it holds, which can be
- anything. Any implementation dependent logic belongs in the
+ anything. Any logic specific to the connection type belongs in the
ConnectionStrategy, which should be passed to the pool's constructor via
the `strategy` parameter.
"""
@@ -63,6 +70,7 @@ def __init__(
raise ValueError("burst_limit must be greater than or equal to max_size")
self.in_use = 0
self.currently_allocating = 0
+ self.currently_deallocating = 0
self.available: "asyncio.Queue[Conn]" = asyncio.Queue(maxsize=self.max_size)

@property
@@ -90,11 +98,10 @@ async def _connection_waiter(self):
def _get_conn(self) -> "Awaitable[Conn]":
# This function is how we avoid explicitly locking. Since it is
# synchronous, we do all the "book-keeping" required to get a
- # connection synchronously (i.e. implicitly holding the GIL), and
- # return a Future or Task which can be awaited after this function
- # returns.
+ # connection synchronously, and return a Future or Task which can be
+ # awaited after this function returns.
#
- # The most important thing here is that we have the GIL from when we
+ # The most important thing here is that we do not await from when we
# measure values like `self._total` or `self.available.empty()` until
# we change values that affect those measurements. In other words,
# taking a connection must be an atomic operation.
@@ -114,13 +121,14 @@ def _get_conn(self) -> "Awaitable[Conn]":
# Returns a Task that resolves to the new connection, which can be
# awaited.
#
- # If there are a lot of threads waiting for a connection, to avoid
+ # If there are a lot of tasks waiting for a connection, to avoid
# having all of them time out and be cancelled, we'll burst to
# higher max_size.
self.currently_allocating += 1
return self._loop.create_task(self._connection_maker())
else:
- # Return a Task that waits for the next task to appear in the queue.
+ # Return a Task that waits for the next connection to appear in the
+ # queue.
return self._loop.create_task(self._connection_waiter())

@asynccontextmanager
@@ -144,22 +152,30 @@ async def get_connection(self) -> AsyncIterator[Conn]: # type: ignore
yield conn
finally:
# Return the connection to the pool.
- self.in_use -= 1
- assert self.in_use >= 0, "More connections returned than given"

+ self.currently_deallocating += 1
try:
# Check if we are currently over-committed (i.e. bursting)
- if self._total >= self.max_size and self._waiters == 0:
+ if (
+ self._total - self.currently_deallocating >= self.max_size
+ and self._waiters == 0
+ ):
# We had created extra connections to handle burst load,
# but there are no more waiters, so we don't need this
# connection anymore.
- self.strategy.close_connection(conn)
+ await _close_connection_compat(self.strategy, conn)
else:
- self.available.put_nowait(conn)
- except asyncio.QueueFull:
- # We don't actually check if the queue has room before trying
- # to put the connection into it. It's unclear whether we could
- # have a full queue and still have waiters, but we should
- # handle this case to be safe (otherwise we would leak
- # connections).
- self.strategy.close_connection(conn)
+ try:
+ self.available.put_nowait(conn)
+ except asyncio.QueueFull:
+ # We don't actually check if the queue has room before
+ # trying to put the connection into it. It's unclear
+ # whether we could have a full queue and still have
+ # waiters, but we should handle this case to be safe
+ # (otherwise we would leak connections).
+ await _close_connection_compat(self.strategy, conn)
+ finally:
+ # Consider the connection closed even if an exception is raised
+ # in the strategy's close_connection.
+ self.currently_deallocating -= 1
+ self.in_use -= 1
+ assert self.in_use >= 0, "More connections returned than given"
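
A minimal usage sketch of the burst behaviour described in the docstring above, assuming the package is installed; `CounterStrategy` is made up for illustration, and whether the pool actually bursts depends on load.

```python
import asyncio

from asyncio_connection_pool import ConnectionPool, ConnectionStrategy


class CounterStrategy(ConnectionStrategy[int]):
    """Illustrative strategy whose "connections" are just increasing integers."""

    def __init__(self) -> None:
        self.next_id = 0
        self.closed = []  # connections the pool asked us to close

    async def make_connection(self) -> int:
        self.next_id += 1
        return self.next_id

    def connection_is_closed(self, conn: int) -> bool:
        return False

    async def close_connection(self, conn: int) -> None:
        # Awaited by the pool (via _close_connection_compat) when it shrinks
        # back down after bursting.
        await asyncio.sleep(0)
        self.closed.append(conn)


async def use_connection(pool: ConnectionPool[int]) -> None:
    async with pool.get_connection() as conn:
        await asyncio.sleep(0.01)  # simulate work while holding the connection


async def main() -> None:
    strategy = CounterStrategy()
    # max_size is a soft limit here; under load the pool may open up to
    # burst_limit connections and close the extras once demand subsides.
    pool = ConnectionPool(strategy=strategy, max_size=2, burst_limit=10)
    await asyncio.gather(*(use_connection(pool) for _ in range(25)))
    print("connections opened:", strategy.next_id)
    print("connections closed by the pool:", strategy.closed)


asyncio.run(main())
```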
3 changes: 2 additions & 1 deletion asyncio_connection_pool/contrib/aioredis.py
@@ -15,5 +15,6 @@ async def make_connection(self):
def connection_is_closed(self, conn):
return conn.closed

- def close_connection(self, conn):
+ async def close_connection(self, conn):
conn.close()
+ await conn.wait_closed()
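
A sketch of how a strategy like this contrib one might be wired up to the pool, assuming the aioredis 1.x API (`create_redis`, `closed`, `close()`, `wait_closed()`); the `RedisStrategy` name is illustrative, not the class actually exported by `asyncio_connection_pool.contrib.aioredis`.

```python
import asyncio

import aioredis  # aioredis < 2.0.0, as noted in the README change above

from asyncio_connection_pool import ConnectionPool, ConnectionStrategy


class RedisStrategy(ConnectionStrategy):
    """Hand-rolled sketch mirroring the contrib strategy's diff above."""

    def __init__(self, url: str = "redis://localhost") -> None:
        self.url = url

    async def make_connection(self):
        return await aioredis.create_redis(self.url)

    def connection_is_closed(self, conn) -> bool:
        return conn.closed

    async def close_connection(self, conn) -> None:
        conn.close()
        await conn.wait_closed()


async def main() -> None:
    pool = ConnectionPool(strategy=RedisStrategy(), max_size=10, burst_limit=20)
    async with pool.get_connection() as redis:
        await redis.set("greeting", "hello")


asyncio.run(main())
```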
12 changes: 5 additions & 7 deletions test/test_pool.py
@@ -15,7 +15,7 @@ def pool_cls(request):
return request.param


- class RandomIntStrategy(ConnectionStrategy):
+ class RandomIntStrategy(ConnectionStrategy[int]):
async def make_connection(self):
import random

@@ -62,9 +62,7 @@ async def inc(self):

@pytest.mark.asyncio
async def test_concurrent_get_connection(pool_cls):
"""Test handling several connection requests in a short time. (Not truly
concurrent because of the GIL)
"""
"""Test handling several connection requests in a short time."""

pool = pool_cls(strategy=RandomIntStrategy(), max_size=20)
nworkers = 10
@@ -97,7 +95,7 @@ async def test_currently_allocating(pool_cls):

ev = asyncio.Event()

- class WaitStrategy(ConnectionStrategy):
+ class WaitStrategy(ConnectionStrategy[None]):
async def make_connection(self):
await ev.wait()

@@ -214,7 +212,7 @@ async def test_stale_connections(pool_cls):

stale_connections = {1, 2, 3, 4}

- class Strategy(ConnectionStrategy):
+ class Strategy(ConnectionStrategy[int]):
def __init__(self):
from itertools import count

@@ -254,7 +252,7 @@ async def worker():
async def test_handling_cancellederror():
making_connection = asyncio.Event()

- class Strategy(ConnectionStrategy):
+ class Strategy(ConnectionStrategy[int]):
async def make_connection(self):
making_connection.set()
await asyncio.Event().wait() # wait forever
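
A possible companion test for the new compatibility shim; `_close_connection_compat` is a private helper in the pool module, so this is an illustrative sketch rather than a committed test.

```python
import pytest

from asyncio_connection_pool import ConnectionStrategy, _close_connection_compat


@pytest.mark.asyncio
async def test_close_connection_compat_sketch():
    closed = []

    class LegacyStrategy(ConnectionStrategy[int]):
        async def make_connection(self):
            return 0

        def connection_is_closed(self, conn):
            return False

        def close_connection(self, conn):  # old synchronous signature
            closed.append(("sync", conn))

    class ModernStrategy(LegacyStrategy):
        async def close_connection(self, conn):  # new asynchronous signature
            closed.append(("async", conn))

    # The shim awaits the result only when the strategy returned an awaitable,
    # so both styles of close_connection are handled.
    await _close_connection_compat(LegacyStrategy(), 1)
    await _close_connection_compat(ModernStrategy(), 2)
    assert closed == [("sync", 1), ("async", 2)]
```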
