"""Redis cache backend."""
import pickle
import random
import re
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
class RedisSerializer:
    """Pickle-based serializer for values stored in Redis.

    Objects whose type is exactly ``int`` are passed through unpickled;
    everything else (including ``bool`` and other ``int`` subclasses) is
    pickled with the configured protocol.
    """

    def __init__(self, protocol=None):
        """Remember the pickle protocol, defaulting to the highest one."""
        if protocol is None:
            protocol = pickle.HIGHEST_PROTOCOL
        self.protocol = protocol

    def dumps(self, obj):
        """Return ``obj`` unchanged if it is a plain int, else its pickle."""
        # An exact type check is used on purpose: subclasses of int such as
        # bool must be pickled so they come back with their original type.
        if type(obj) is not int:
            return pickle.dumps(obj, self.protocol)
        return obj

    def loads(self, data):
        """Return the value encoded by ``data``.

        Data that ``int()`` accepts is returned as an integer; anything else
        is handed to ``pickle.loads()``.
        """
        try:
            number = int(data)
        except ValueError:
            return pickle.loads(data)
        else:
            return number
class RedisCacheClient:
    """Low-level cache operations on one or more Redis servers.

    Write operations use the first configured server. Read operations use a
    randomly chosen server other than the first when more than one server is
    configured. One connection pool per server index is created on first use
    and then reused.
    """

    def __init__(
        self,
        servers,
        serializer=None,
        pool_class=None,
        parser_class=None,
        **options,
    ):
        """Configure the client.

        ``serializer``, ``pool_class`` and ``parser_class`` may each be given
        as a dotted-path string, which is resolved with ``import_string()``.
        A callable ``serializer`` is called with no arguments to obtain the
        serializer instance. Remaining keyword ``options`` are forwarded to
        the connection pool together with ``parser_class``.
        """
        # Imported here rather than at module level, so the redis package is
        # only required once a client is actually instantiated.
        import redis

        self._lib = redis
        self._servers = servers
        self._pools = {}
        self._client = self._lib.Redis

        if isinstance(pool_class, str):
            pool_class = import_string(pool_class)
        if not pool_class:
            pool_class = self._lib.ConnectionPool
        self._pool_class = pool_class

        if isinstance(serializer, str):
            serializer = import_string(serializer)
        if callable(serializer):
            serializer = serializer()
        if not serializer:
            serializer = RedisSerializer()
        self._serializer = serializer

        if isinstance(parser_class, str):
            parser_class = import_string(parser_class)
        if not parser_class:
            parser_class = self._lib.connection.DefaultParser

        pool_options = {"parser_class": parser_class}
        pool_options.update(options)
        self._pool_options = pool_options

    def _get_connection_pool_index(self, write):
        """Return the index of the server to use for this operation."""
        # Writes always go to the first server.
        if write:
            return 0
        # Reads go to one of the other servers if there are any, otherwise
        # they fall back to the first server as well.
        last_index = len(self._servers) - 1
        if last_index == 0:
            return 0
        return random.randint(1, last_index)

    def _get_connection_pool(self, write):
        """Return the pool for the selected server, creating it if needed."""
        index = self._get_connection_pool_index(write)
        if index in self._pools:
            return self._pools[index]
        self._pools[index] = self._pool_class.from_url(
            self._servers[index],
            **self._pool_options,
        )
        return self._pools[index]

    def get_client(self, key=None, *, write=False):
        """Return a Redis client bound to the appropriate connection pool."""
        # ``key`` is unused here. It is part of the signature so that a custom
        # cache client which needs the key to pick a server (e.g. sharding)
        # can override this method without changing the call sites.
        return self._client(connection_pool=self._get_connection_pool(write))

    def add(self, key, value, timeout):
        """Store ``value`` only if ``key`` is absent; return whether it was set."""
        client = self.get_client(key, write=True)
        payload = self._serializer.dumps(value)
        if timeout != 0:
            return bool(client.set(key, payload, ex=timeout, nx=True))
        # A zero timeout: set the key to learn whether it was absent, then
        # remove it again straight away if the set succeeded.
        was_set = bool(client.set(key, payload, nx=True))
        if was_set:
            client.delete(key)
        return was_set

    def get(self, key, default):
        """Return the deserialized value for ``key``, or ``default`` if missing."""
        raw = self.get_client(key).get(key)
        if raw is None:
            return default
        return self._serializer.loads(raw)

    def set(self, key, value, timeout):
        """Store ``value`` under ``key``; a zero timeout deletes the key instead."""
        client = self.get_client(key, write=True)
        payload = self._serializer.dumps(value)
        if timeout == 0:
            client.delete(key)
            return
        client.set(key, payload, ex=timeout)

    def touch(self, key, timeout):
        """Update the expiry of ``key``; ``None`` removes the expiry."""
        client = self.get_client(key, write=True)
        if timeout is None:
            outcome = client.persist(key)
        else:
            outcome = client.expire(key, timeout)
        return bool(outcome)

    def delete(self, key):
        """Delete ``key`` and return the truthiness of the client's reply."""
        return bool(self.get_client(key, write=True).delete(key))

    def get_many(self, keys):
        """Return a dict of the deserialized values for the keys that exist."""
        client = self.get_client(None)
        raw_values = client.mget(keys)
        found = {}
        for name, raw in zip(keys, raw_values):
            # Missing keys come back as None and are left out of the result.
            if raw is not None:
                found[name] = self._serializer.loads(raw)
        return found

    def has_key(self, key):
        """Return True if the client reports that ``key`` exists."""
        return bool(self.get_client(key).exists(key))

    def incr(self, key, delta):
        """Increment ``key`` by ``delta`` and return the client's result.

        Raises ValueError if the key does not exist.
        """
        client = self.get_client(key, write=True)
        if client.exists(key):
            return client.incr(key, delta)
        raise ValueError("Key '%s' not found." % key)

    def set_many(self, data, timeout):
        """Store every item of ``data`` using a single pipeline."""
        client = self.get_client(None, write=True)
        pipeline = client.pipeline()
        serialized = {
            name: self._serializer.dumps(item) for name, item in data.items()
        }
        pipeline.mset(serialized)
        if timeout is not None:
            # mset() takes no timeout, so an expiry is queued per key.
            for name in data:
                pipeline.expire(name, timeout)
        pipeline.execute()

    def delete_many(self, keys):
        """Delete all of ``keys`` with a single command."""
        self.get_client(None, write=True).delete(*keys)

    def clear(self):
        """Flush the current database; return the truthiness of the reply."""
        return bool(self.get_client(None, write=True).flushdb())
class RedisCache(BaseCache):
    """Cache backend that delegates storage to a ``RedisCacheClient``.

    Every public method turns the caller's key into a validated backend key
    with ``make_and_validate_key()`` and converts the timeout with
    ``get_backend_timeout()`` before handing the call to the client.
    """

    def __init__(self, server, params):
        """Record the server list and client options.

        ``server`` is either a sequence of server URLs or a single string in
        which several URLs are separated by ``;`` or ``,``.
        """
        super().__init__(params)
        self._servers = (
            re.split("[;,]", server) if isinstance(server, str) else server
        )
        self._class = RedisCacheClient
        self._options = params.get("OPTIONS", {})

    @cached_property
    def _cache(self):
        """The client instance, created on first access."""
        return self._class(self._servers, **self._options)

    def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
        """Return the timeout to pass to the client.

        ``None`` is passed through and makes the key persistent. Any other
        value is truncated to an int and clamped at zero; non-positive values
        cause the key to be deleted.
        """
        if timeout == DEFAULT_TIMEOUT:
            timeout = self.default_timeout
        if timeout is None:
            return None
        return max(0, int(timeout))

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        """Store ``value`` only if the key is not already present."""
        backend_key = self.make_and_validate_key(key, version=version)
        return self._cache.add(
            backend_key, value, self.get_backend_timeout(timeout)
        )

    def get(self, key, default=None, version=None):
        """Return the cached value for ``key``, or ``default`` if missing."""
        backend_key = self.make_and_validate_key(key, version=version)
        return self._cache.get(backend_key, default)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        """Store ``value`` under ``key``."""
        backend_key = self.make_and_validate_key(key, version=version)
        self._cache.set(backend_key, value, self.get_backend_timeout(timeout))

    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
        """Update the expiry of ``key`` and return the client's result."""
        backend_key = self.make_and_validate_key(key, version=version)
        return self._cache.touch(backend_key, self.get_backend_timeout(timeout))

    def delete(self, key, version=None):
        """Delete ``key`` and return the client's result."""
        backend_key = self.make_and_validate_key(key, version=version)
        return self._cache.delete(backend_key)

    def get_many(self, keys, version=None):
        """Return a dict mapping each found key (as given) to its value."""
        # Remember which caller-supplied key each backend key came from so the
        # result can be reported using the caller's own keys.
        original_by_backend_key = {}
        for key in keys:
            backend_key = self.make_and_validate_key(key, version=version)
            original_by_backend_key[backend_key] = key
        found = self._cache.get_many(original_by_backend_key.keys())
        return {
            original_by_backend_key[backend_key]: value
            for backend_key, value in found.items()
        }

    def has_key(self, key, version=None):
        """Return whether ``key`` is present in the cache."""
        backend_key = self.make_and_validate_key(key, version=version)
        return self._cache.has_key(backend_key)

    def incr(self, key, delta=1, version=None):
        """Increment the value stored under ``key`` by ``delta``."""
        backend_key = self.make_and_validate_key(key, version=version)
        return self._cache.incr(backend_key, delta)

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
        """Store every item of ``data``; always returns an empty list."""
        if not data:
            return []
        safe_data = {
            self.make_and_validate_key(key, version=version): value
            for key, value in data.items()
        }
        self._cache.set_many(safe_data, self.get_backend_timeout(timeout))
        return []

    def delete_many(self, keys, version=None):
        """Delete all of ``keys``; does nothing when ``keys`` is empty."""
        if not keys:
            return
        safe_keys = []
        for key in keys:
            safe_keys.append(self.make_and_validate_key(key, version=version))
        self._cache.delete_many(safe_keys)

    def clear(self):
        """Remove everything from the cache via the client's ``clear()``."""
        return self._cache.clear()