Skip to content

Commit 2df2056

Browse files
Simplify cache namespace and key encoding logic for v1 (#670)
Co-authored-by: Sam Bull <[email protected]>
1 parent b2036b3 commit 2df2056

17 files changed

+194
-148
lines changed

CHANGES.rst

+5
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@ Migration instructions
1212

1313
There are a number of backwards-incompatible changes. These points should help with migrating from an older release:
1414

15+
* The ``key_builder`` parameter now expects a callback which accepts 2 strings and returns a string in all cache implementations, making the builders simpler and interchangeable.
1516
* The ``key`` parameter has been removed from the ``cached`` decorator. The behaviour can be easily reimplemented with ``key_builder=lambda *a, **kw: "foo"``
1617
* When using the ``key_builder`` parameter in ``@multicached``, the function will now return the original, unmodified keys, only using the transformed keys in the cache (this has always been the documented behaviour, but not the implemented behaviour).
1718
* ``BaseSerializer`` is now an ``ABC``, so cannot be instantiated directly.
19+
* If subclassing ``BaseCache`` to implement a custom backend:
20+
21+
* The cache key type used by the backend must now be specified when inheriting (e.g. ``BaseCache[str]`` typically).
22+
* The ``build_key()`` method must now be defined (this should generally involve calling ``self._str_build_key()`` as a helper).
1823

1924

2025
0.12.0 (2023-01-13)

aiocache/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Dict, Type
2+
from typing import Any, Dict, Type
33

44
from .backends.memory import SimpleMemoryCache
55
from .base import BaseCache
@@ -8,7 +8,7 @@
88

99
logger = logging.getLogger(__name__)
1010

11-
AIOCACHE_CACHES: Dict[str, Type[BaseCache]] = {SimpleMemoryCache.NAME: SimpleMemoryCache}
11+
AIOCACHE_CACHES: Dict[str, Type[BaseCache[Any]]] = {SimpleMemoryCache.NAME: SimpleMemoryCache}
1212

1313
try:
1414
import redis

aiocache/backends/memcached.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import asyncio
2+
from typing import Optional
23

34
import aiomcache
45

56
from aiocache.base import BaseCache
67
from aiocache.serializers import JsonSerializer
78

89

9-
class MemcachedBackend(BaseCache):
10+
class MemcachedBackend(BaseCache[bytes]):
1011
def __init__(self, endpoint="127.0.0.1", port=11211, pool_size=2, **kwargs):
1112
super().__init__(**kwargs)
1213
self.endpoint = endpoint
@@ -130,7 +131,7 @@ class MemcachedCache(MemcachedBackend):
130131
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
131132
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
132133
:param namespace: string to use as default prefix for the key used in all operations of
133-
the backend. Default is None
134+
the backend. Default is an empty string, "".
134135
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
135136
By default its 5.
136137
:param endpoint: str with the endpoint to connect to. Default is 127.0.0.1.
@@ -147,8 +148,8 @@ def __init__(self, serializer=None, **kwargs):
147148
def parse_uri_path(cls, path):
148149
return {}
149150

150-
def _build_key(self, key, namespace=None):
151-
ns_key = super()._build_key(key, namespace=namespace).replace(" ", "_")
151+
def build_key(self, key: str, namespace: Optional[str] = None) -> bytes:
152+
ns_key = self._str_build_key(key, namespace).replace(" ", "_")
152153
return str.encode(ns_key)
153154

154155
def __repr__(self): # pragma: no cover

aiocache/backends/memory.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
2-
from typing import Dict
2+
from typing import Dict, Optional
33

44
from aiocache.base import BaseCache
55
from aiocache.serializers import NullSerializer
66

77

8-
class SimpleMemoryBackend(BaseCache):
8+
class SimpleMemoryBackend(BaseCache[str]):
99
"""
1010
Wrapper around dict operations to use it as a cache backend
1111
"""
@@ -118,7 +118,7 @@ class SimpleMemoryCache(SimpleMemoryBackend):
118118
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
119119
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
120120
:param namespace: string to use as default prefix for the key used in all operations of
121-
the backend. Default is None.
121+
the backend. Default is an empty string, "".
122122
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
123123
By default its 5.
124124
"""
@@ -131,3 +131,6 @@ def __init__(self, serializer=None, **kwargs):
131131
@classmethod
132132
def parse_uri_path(cls, path):
133133
return {}
134+
135+
def build_key(self, key: str, namespace: Optional[str] = None) -> str:
136+
return self._str_build_key(key, namespace)

aiocache/backends/redis.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
import itertools
22
import warnings
3+
from typing import Any, Callable, Optional, TYPE_CHECKING
34

45
import redis.asyncio as redis
56
from redis.exceptions import ResponseError as IncrbyException
67

7-
from aiocache.base import BaseCache, _ensure_key
8+
from aiocache.base import BaseCache
89
from aiocache.serializers import JsonSerializer
910

11+
if TYPE_CHECKING: # pragma: no cover
12+
from aiocache.serializers import BaseSerializer
13+
1014

1115
_NOT_SET = object()
1216

1317

14-
class RedisBackend(BaseCache):
18+
class RedisBackend(BaseCache[str]):
1519
RELEASE_SCRIPT = (
1620
"if redis.call('get',KEYS[1]) == ARGV[1] then"
1721
" return redis.call('del',KEYS[1])"
@@ -186,7 +190,7 @@ class RedisCache(RedisBackend):
186190
:param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
187191
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
188192
:param namespace: string to use as default prefix for the key used in all operations of
189-
the backend. Default is None.
193+
the backend. Default is an empty string, "".
190194
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
191195
By default its 5.
192196
:param endpoint: str with the endpoint to connect to. Default is "127.0.0.1".
@@ -199,8 +203,21 @@ class RedisCache(RedisBackend):
199203

200204
NAME = "redis"
201205

202-
def __init__(self, serializer=None, **kwargs):
203-
super().__init__(serializer=serializer or JsonSerializer(), **kwargs)
206+
def __init__(
207+
self,
208+
serializer: Optional["BaseSerializer"] = None,
209+
namespace: str = "",
210+
key_builder: Optional[Callable[[str, str], str]] = None,
211+
**kwargs: Any,
212+
):
213+
super().__init__(
214+
serializer=serializer or JsonSerializer(),
215+
namespace=namespace,
216+
key_builder=key_builder or (
217+
lambda key, namespace: f"{namespace}:{key}" if namespace else key
218+
),
219+
**kwargs,
220+
)
204221

205222
@classmethod
206223
def parse_uri_path(cls, path):
@@ -218,14 +235,8 @@ def parse_uri_path(cls, path):
218235
options["db"] = db
219236
return options
220237

221-
def _build_key(self, key, namespace=None):
222-
if namespace is not None:
223-
return "{}{}{}".format(
224-
namespace, ":" if namespace else "", _ensure_key(key))
225-
if self.namespace is not None:
226-
return "{}{}{}".format(
227-
self.namespace, ":" if self.namespace else "", _ensure_key(key))
228-
return key
229-
230238
def __repr__(self): # pragma: no cover
231239
return "RedisCache ({}:{})".format(self.endpoint, self.port)
240+
241+
def build_key(self, key: str, namespace: Optional[str] = None) -> str:
242+
return self._str_build_key(key, namespace)

aiocache/base.py

+47-50
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,22 @@
33
import logging
44
import os
55
import time
6+
from abc import abstractmethod
67
from enum import Enum
78
from types import TracebackType
8-
from typing import Callable, Optional, Set, Type
9+
from typing import Callable, Generic, List, Optional, Set, TYPE_CHECKING, Type, TypeVar
910

10-
from aiocache import serializers
11+
from aiocache.serializers import StringSerializer
12+
13+
if TYPE_CHECKING: # pragma: no cover
14+
from aiocache.plugins import BasePlugin
15+
from aiocache.serializers import BaseSerializer
1116

1217

1318
logger = logging.getLogger(__name__)
1419

1520
SENTINEL = object()
21+
CacheKeyType = TypeVar("CacheKeyType")
1622

1723

1824
class API:
@@ -87,7 +93,7 @@ async def _plugins(self, *args, **kwargs):
8793
return _plugins
8894

8995

90-
class BaseCache:
96+
class BaseCache(Generic[CacheKeyType]):
9197
"""
9298
Base class that agregates the common logic for the different caches that may exist. Cache
9399
related available options are:
@@ -97,9 +103,9 @@ class BaseCache:
97103
:param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes. Default is empty
98104
list.
99105
:param namespace: string to use as default prefix for the key used in all operations of
100-
the backend. Default is None
106+
the backend. Default is an empty string, "".
101107
:param key_builder: alternative callable to build the key. Receives the key and the namespace
102-
as params and should return something that can be used as key by the underlying backend.
108+
as params and should return a string that can be used as a key by the underlying backend.
103109
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
104110
By default its 5. Use 0 or None if you want to disable it.
105111
:param ttl: int the expiration time in seconds to use as a default in all operations of
@@ -109,18 +115,22 @@ class BaseCache:
109115
NAME: str
110116

111117
def __init__(
112-
self, serializer=None, plugins=None, namespace=None, key_builder=None, timeout=5, ttl=None
118+
self,
119+
serializer: Optional["BaseSerializer"] = None,
120+
plugins: Optional[List["BasePlugin"]] = None,
121+
namespace: str = "",
122+
key_builder: Callable[[str, str], str] = lambda key, namespace: f"{namespace}{key}",
123+
timeout: Optional[float] = 5,
124+
ttl: Optional[float] = None,
113125
):
114-
self.timeout = float(timeout) if timeout is not None else timeout
115-
self.namespace = namespace
116-
self.ttl = float(ttl) if ttl is not None else ttl
117-
self.build_key = key_builder or self._build_key
126+
self.timeout = float(timeout) if timeout is not None else None
127+
self.ttl = float(ttl) if ttl is not None else None
118128

119-
self._serializer = None
120-
self.serializer = serializer or serializers.StringSerializer()
129+
self.namespace = namespace
130+
self._build_key = key_builder
121131

122-
self._plugins = None
123-
self.plugins = plugins or []
132+
self._serializer = serializer or StringSerializer()
133+
self._plugins = plugins or []
124134

125135
@property
126136
def serializer(self):
@@ -162,9 +172,8 @@ async def add(self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _co
162172
- :class:`asyncio.TimeoutError` if it lasts more than self.timeout
163173
"""
164174
start = time.monotonic()
165-
dumps = dumps_fn or self._serializer.dumps
166-
ns = namespace if namespace is not None else self.namespace
167-
ns_key = self.build_key(key, namespace=ns)
175+
dumps = dumps_fn or self.serializer.dumps
176+
ns_key = self.build_key(key, namespace)
168177

169178
await self._add(ns_key, dumps(value), ttl=self._get_ttl(ttl), _conn=_conn)
170179

@@ -192,9 +201,8 @@ async def get(self, key, default=None, loads_fn=None, namespace=None, _conn=None
192201
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
193202
"""
194203
start = time.monotonic()
195-
loads = loads_fn or self._serializer.loads
196-
ns = namespace if namespace is not None else self.namespace
197-
ns_key = self.build_key(key, namespace=ns)
204+
loads = loads_fn or self.serializer.loads
205+
ns_key = self.build_key(key, namespace)
198206

199207
value = loads(await self._get(ns_key, encoding=self.serializer.encoding, _conn=_conn))
200208

@@ -224,10 +232,9 @@ async def multi_get(self, keys, loads_fn=None, namespace=None, _conn=None):
224232
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
225233
"""
226234
start = time.monotonic()
227-
loads = loads_fn or self._serializer.loads
228-
ns = namespace if namespace is not None else self.namespace
235+
loads = loads_fn or self.serializer.loads
229236

230-
ns_keys = [self.build_key(key, namespace=ns) for key in keys]
237+
ns_keys = [self.build_key(key, namespace) for key in keys]
231238
values = [
232239
loads(value)
233240
for value in await self._multi_get(
@@ -269,9 +276,8 @@ async def set(
269276
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
270277
"""
271278
start = time.monotonic()
272-
dumps = dumps_fn or self._serializer.dumps
273-
ns = namespace if namespace is not None else self.namespace
274-
ns_key = self.build_key(key, namespace=ns)
279+
dumps = dumps_fn or self.serializer.dumps
280+
ns_key = self.build_key(key, namespace)
275281

276282
res = await self._set(
277283
ns_key, dumps(value), ttl=self._get_ttl(ttl), _cas_token=_cas_token, _conn=_conn
@@ -303,12 +309,11 @@ async def multi_set(self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _c
303309
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
304310
"""
305311
start = time.monotonic()
306-
dumps = dumps_fn or self._serializer.dumps
307-
ns = namespace if namespace is not None else self.namespace
312+
dumps = dumps_fn or self.serializer.dumps
308313

309314
tmp_pairs = []
310315
for key, value in pairs:
311-
tmp_pairs.append((self.build_key(key, namespace=ns), dumps(value)))
316+
tmp_pairs.append((self.build_key(key, namespace), dumps(value)))
312317

313318
await self._multi_set(tmp_pairs, ttl=self._get_ttl(ttl), _conn=_conn)
314319

@@ -339,8 +344,7 @@ async def delete(self, key, namespace=None, _conn=None):
339344
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
340345
"""
341346
start = time.monotonic()
342-
ns = namespace if namespace is not None else self.namespace
343-
ns_key = self.build_key(key, namespace=ns)
347+
ns_key = self.build_key(key, namespace)
344348
ret = await self._delete(ns_key, _conn=_conn)
345349
logger.debug("DELETE %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
346350
return ret
@@ -364,8 +368,7 @@ async def exists(self, key, namespace=None, _conn=None):
364368
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
365369
"""
366370
start = time.monotonic()
367-
ns = namespace if namespace is not None else self.namespace
368-
ns_key = self.build_key(key, namespace=ns)
371+
ns_key = self.build_key(key, namespace)
369372
ret = await self._exists(ns_key, _conn=_conn)
370373
logger.debug("EXISTS %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
371374
return ret
@@ -392,8 +395,7 @@ async def increment(self, key, delta=1, namespace=None, _conn=None):
392395
:raises: :class:`TypeError` if value is not incrementable
393396
"""
394397
start = time.monotonic()
395-
ns = namespace if namespace is not None else self.namespace
396-
ns_key = self.build_key(key, namespace=ns)
398+
ns_key = self.build_key(key, namespace)
397399
ret = await self._increment(ns_key, delta, _conn=_conn)
398400
logger.debug("INCREMENT %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
399401
return ret
@@ -418,8 +420,7 @@ async def expire(self, key, ttl, namespace=None, _conn=None):
418420
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
419421
"""
420422
start = time.monotonic()
421-
ns = namespace if namespace is not None else self.namespace
422-
ns_key = self.build_key(key, namespace=ns)
423+
ns_key = self.build_key(key, namespace)
423424
ret = await self._expire(ns_key, ttl, _conn=_conn)
424425
logger.debug("EXPIRE %s %d (%.4f)s", ns_key, ret, time.monotonic() - start)
425426
return ret
@@ -498,12 +499,15 @@ async def close(self, *args, _conn=None, **kwargs):
498499
async def _close(self, *args, **kwargs):
499500
pass
500501

501-
def _build_key(self, key, namespace=None):
502-
if namespace is not None:
503-
return "{}{}".format(namespace, _ensure_key(key))
504-
if self.namespace is not None:
505-
return "{}{}".format(self.namespace, _ensure_key(key))
506-
return key
502+
@abstractmethod
503+
def build_key(self, key: str, namespace: Optional[str] = None) -> CacheKeyType:
504+
raise NotImplementedError()
505+
506+
def _str_build_key(self, key: str, namespace: Optional[str] = None) -> str:
507+
"""Simple key builder that can be used in subclasses for build_key()."""
508+
key_name = key.value if isinstance(key, Enum) else key
509+
ns = self.namespace if namespace is None else namespace
510+
return self._build_key(key_name, ns)
507511

508512
def _get_ttl(self, ttl):
509513
return ttl if ttl is not SENTINEL else self.ttl
@@ -550,12 +554,5 @@ async def _do_inject_conn(self, *args, **kwargs):
550554
return _do_inject_conn
551555

552556

553-
def _ensure_key(key):
554-
if isinstance(key, Enum):
555-
return key.value
556-
else:
557-
return key
558-
559-
560557
for cmd in API.CMDS:
561558
setattr(_Conn, cmd.__name__, _Conn._inject_conn(cmd.__name__))

0 commit comments

Comments
 (0)