Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement simple asyncio wrapper API with basic tests #646

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions kazoo/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Simple asyncio integration of the threaded async executor engine.
"""
92 changes: 92 additions & 0 deletions kazoo/aio/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import asyncio

from kazoo.aio.handler import AioSequentialThreadingHandler
from kazoo.client import KazooClient, TransactionRequest


class AioKazooClient(KazooClient):
"""
The asyncio compatibility mostly mimics the behaviour of the base async
one. All calls are wrapped in asyncio.shield() to prevent cancellation
that is not supported in the base async implementation.

The sync and base-async API are still completely functional. Mixing the
use of any of the 3 should be okay.
"""

def __init__(self, *args, **kwargs):
if not kwargs.get("handler"):
kwargs["handler"] = AioSequentialThreadingHandler()
KazooClient.__init__(self, *args, **kwargs)

# asyncio compatible api wrappers
async def start_aio(self, timeout=15):
"""
There is no protection for calling this multiple times in parallel.
The start_async() seems to lack that as well. Maybe it is allowed and
handled internally.
"""
await self.handler.loop.run_in_executor(None, self.start, timeout)

async def add_auth_aio(self, *args, **kwargs):
return await asyncio.shield(
self.add_auth_async(*args, **kwargs).future
)

async def sync_aio(self, *args, **kwargs):
return await asyncio.shield(self.sync_async(*args, **kwargs).future)

async def create_aio(self, *args, **kwargs):
return await asyncio.shield(self.create_async(*args, **kwargs).future)

async def ensure_path_aio(self, *args, **kwargs):
return await asyncio.shield(
self.ensure_path_async(*args, **kwargs).future
)

async def exists_aio(self, *args, **kwargs):
return await asyncio.shield(self.exists_async(*args, **kwargs).future)

async def get_aio(self, *args, **kwargs):
return await asyncio.shield(self.get_async(*args, **kwargs).future)

async def get_children_aio(self, *args, **kwargs):
return await asyncio.shield(
self.get_children_async(*args, **kwargs).future
)

async def get_acls_aio(self, *args, **kwargs):
return await asyncio.shield(
self.get_acls_async(*args, **kwargs).future
)

async def set_acls_aio(self, *args, **kwargs):
return await asyncio.shield(
self.set_acls_async(*args, **kwargs).future
)

async def set_aio(self, *args, **kwargs):
return await asyncio.shield(self.set_async(*args, **kwargs).future)

def transaction_aio(self):
return AioTransactionRequest(self)

async def delete_aio(self, *args, **kwargs):
return await asyncio.shield(self.delete_async(*args, **kwargs).future)

async def reconfig_aio(self, *args, **kwargs):
return await asyncio.shield(
self.reconfig_async(*args, **kwargs).future
)


class AioTransactionRequest(TransactionRequest):
async def commit_aio(self):
return await asyncio.shield(self.commit_async().future)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
if not exc_type:
await self.commit_aio()
60 changes: 60 additions & 0 deletions kazoo/aio/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import asyncio
import threading

from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler


class AioAsyncResult(AsyncResult):
def __init__(self, handler):
self.future = handler.loop.create_future()
AsyncResult.__init__(self, handler)

def set(self, value=None):
"""
The completion of the future has the same guarantees as the
notification emitting of the condition.
Provided that no callbacks raise it will complete.
"""
AsyncResult.set(self, value)
self._handler.loop.call_soon_threadsafe(self.future.set_result, value)

def set_exception(self, exception):
"""
The completion of the future has the same guarantees as the
notification emitting of the condition.
Provided that no callbacks raise it will complete.
"""
AsyncResult.set_exception(self, exception)
self._handler.loop.call_soon_threadsafe(
self.future.set_exception, exception
)


class AioSequentialThreadingHandler(SequentialThreadingHandler):
def __init__(self):
"""
Creating the handler must be done on the asyncio-loop's thread.
"""
self.loop = asyncio.get_running_loop()
self._aio_thread = threading.current_thread()
SequentialThreadingHandler.__init__(self)

def async_result(self, api=False):
"""
Almost all async-result objects are created by a method that is
invoked from the user's thead. The one exception I'm aware of is
in the PatientChildrenWatch utility, that creates an async-result
in its worker thread. Just because of that it is imperative to
only create asyncio compatible results when the invoking code is
from the loop's thread. There is no PEP/API guarantee that
implementing the create_future() has to be thread-safe. The default
is mostly thread-safe. The only thing that may get synchronization
issue is a debug-feature for asyncio development. Quickly looking at
the alternate implementation of uvloop, they use the default Future
implementation, so no change there.
For now, just to be safe, we check the current thread and create an
async-result object based on the invoking thread's identity.
"""
if api and threading.current_thread() is self._aio_thread:
return AioAsyncResult(self)
return AsyncResult(self)
91 changes: 91 additions & 0 deletions kazoo/aio/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import asyncio
import random
import time
from functools import partial

from kazoo.exceptions import (
ConnectionClosedError,
ConnectionLoss,
OperationTimeoutError,
SessionExpiredError,
)
from kazoo.retry import ForceRetryError, RetryFailedError


EXCEPTIONS = (
ConnectionLoss,
OperationTimeoutError,
ForceRetryError,
)

EXCEPTIONS_WITH_EXPIRED = EXCEPTIONS + (SessionExpiredError,)


def kazoo_retry_aio(
max_tries=1,
delay=0.1,
backoff=2,
max_jitter=0.4,
max_delay=60.0,
ignore_expire=True,
deadline=None,
):
"""
This is similar to KazooRetry, but they do not have compatible
interfaces. The threaded and asyncio constructs are too different
to easily wrap the KazooRetry implementation. Unless, all retries
always get their own thread to work in. This is much more lightweight
compared to the object-copying and resetting implementation.

There is no equivalent analogue to the interrupt API.
If interrupting the retry is necessary, it must be wrapped in
an asyncio.Task, which can be cancelled. Be aware though that
this will quit waiting on the Zookeeper API call immediately
unlike the threaded API. There is no way to interrupt/cancel an
internal request thread so it will continue and stop eventually
on its own. This means caller can't know if the call is still
in progress and may succeed or the retry was cancelled while it
was waiting for delay.

Usage example. These are equivalent except that the latter lines
will retry the requests on specific exceptions:
await zk.create_aio("/x")
await zk.create_aio("/x/y")

aio_retry = kazoo_retry_aio()
await aio_retry(zk.create_aio, "/x")
await aio_retry(zk.create_aio, "/x/y")
"""
retry_exceptions = (
EXCEPTIONS_WITH_EXPIRED if ignore_expire else EXCEPTIONS
)
max_jitter = max(min(max_jitter, 1.0), 0.0)
get_jitter = partial(random.uniform, 1.0 - max_jitter, 1.0 + max_jitter)
del max_jitter

async def _retry(func, *args, **kwargs):
attempts = 0
cur_delay = delay
stop_time = (
None if deadline is None else time.perf_counter() + deadline
)
while True:
try:
return await func(*args, **kwargs)
except ConnectionClosedError:
raise
except retry_exceptions:
# Note: max_tries == -1 means infinite tries.
if attempts == max_tries:
raise RetryFailedError("Too many retry attempts")
attempts += 1
sleep_time = cur_delay * get_jitter()
if (
stop_time is not None
and time.perf_counter() + sleep_time >= stop_time

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line break before binary operator

):
raise RetryFailedError("Exceeded retry deadline")
await asyncio.sleep(sleep_time)
cur_delay = min(sleep_time * backoff, max_delay)

return _retry
26 changes: 13 additions & 13 deletions kazoo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def add_auth_async(self, scheme, credential):
# we need this auth data to re-authenticate on reconnect
self.auth_data.add((scheme, credential))

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
self._call(Auth(0, scheme, credential), async_result)
return async_result

Expand All @@ -839,7 +839,7 @@ def sync_async(self, path):
:rtype: :class:`~kazoo.interfaces.IAsyncResult`

"""
async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)

@wrap(async_result)
def _sync_completion(result):
Expand Down Expand Up @@ -997,7 +997,7 @@ def create_async(self, path, value=b"", acl=None, ephemeral=False,
if acl is None:
acl = OPEN_ACL_UNSAFE

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)

@capture_exceptions(async_result)
def do_create():
Expand Down Expand Up @@ -1071,7 +1071,7 @@ def ensure_path_async(self, path, acl=None):

"""
acl = acl or self.default_acl
async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)

@wrap(async_result)
def create_completion(result):
Expand Down Expand Up @@ -1134,7 +1134,7 @@ def exists_async(self, path, watch=None):
if watch and not callable(watch):
raise TypeError("Invalid type for 'watch' (must be a callable)")

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
self._call(Exists(_prefix_root(self.chroot, path), watch),
async_result)
return async_result
Expand Down Expand Up @@ -1176,7 +1176,7 @@ def get_async(self, path, watch=None):
if watch and not callable(watch):
raise TypeError("Invalid type for 'watch' (must be a callable)")

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
self._call(GetData(_prefix_root(self.chroot, path), watch),
async_result)
return async_result
Expand Down Expand Up @@ -1232,7 +1232,7 @@ def get_children_async(self, path, watch=None, include_data=False):
if not isinstance(include_data, bool):
raise TypeError("Invalid type for 'include_data' (bool expected)")

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
if include_data:
req = GetChildren2(_prefix_root(self.chroot, path), watch)
else:
Expand Down Expand Up @@ -1270,7 +1270,7 @@ def get_acls_async(self, path):
if not isinstance(path, string_types):
raise TypeError("Invalid type for 'path' (string expected)")

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
self._call(GetACL(_prefix_root(self.chroot, path)), async_result)
return async_result

Expand Down Expand Up @@ -1318,7 +1318,7 @@ def set_acls_async(self, path, acls, version=-1):
if not isinstance(version, int):
raise TypeError("Invalid type for 'version' (int expected)")

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
self._call(SetACL(_prefix_root(self.chroot, path), acls, version),
async_result)
return async_result
Expand Down Expand Up @@ -1372,7 +1372,7 @@ def set_async(self, path, value, version=-1):
if not isinstance(version, int):
raise TypeError("Invalid type for 'version' (int expected)")

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
self._call(SetData(_prefix_root(self.chroot, path), value, version),
async_result)
return async_result
Expand Down Expand Up @@ -1443,7 +1443,7 @@ def delete_async(self, path, version=-1):
raise TypeError("Invalid type for 'path' (string expected)")
if not isinstance(version, int):
raise TypeError("Invalid type for 'version' (int expected)")
async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
self._call(Delete(_prefix_root(self.chroot, path), version),
async_result)
return async_result
Expand Down Expand Up @@ -1556,7 +1556,7 @@ def reconfig_async(self, joining, leaving, new_members, from_config):
if not isinstance(from_config, int):
raise TypeError("Invalid type for 'from_config' (int expected)")

async_result = self.handler.async_result()
async_result = self.handler.async_result(api=True)
reconfig = Reconfig(joining, leaving, new_members, from_config)
self._call(reconfig, async_result)

Expand Down Expand Up @@ -1672,7 +1672,7 @@ def commit_async(self):
"""
self._check_tx_state()
self.committed = True
async_object = self.client.handler.async_result()
async_object = self.client.handler.async_result(api=True)
self.client._call(Transaction(self.operations), async_object)
return async_object

Expand Down
6 changes: 4 additions & 2 deletions kazoo/handlers/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,10 @@ def rlock_object(self):
"""Create an appropriate RLock object"""
return threading.RLock()

def async_result(self):
"""Create a :class:`AsyncResult` instance"""
def async_result(self, api=False):
"""Create a :class:`AsyncResult` instance. The api flag will
indicate if this object will be used by a user code or an
internal one. It is necessary for asyncio support."""
return AsyncResult(self)

def spawn(self, func, *args, **kwargs):
Expand Down
12 changes: 10 additions & 2 deletions kazoo/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from kazoo.testing.harness import KazooTestCase, KazooTestHarness
from kazoo.testing.harness import (
KazooAioTestCase,
KazooTestCase,
KazooTestHarness,
)


__all__ = ('KazooTestHarness', 'KazooTestCase', )
__all__ = (
"KazooTestHarness",
"KazooTestCase",
"KazooAioTestCase",
)
Loading