diff --git a/beaker/cache.py b/beaker/cache.py index 5a1ad6a..1cacfa1 100644 --- a/beaker/cache.py +++ b/beaker/cache.py @@ -6,7 +6,9 @@ :func:`.region_invalidate`. """ +import inspect import warnings +import sys from itertools import chain from beaker._compat import u_, unicode_text, func_signature, bindfuncargs @@ -322,6 +324,12 @@ def get(self, key, **kw): return self._get_value(key, **kw).get_value() get_value = get + if sys.version_info[0] == 3 and sys.version_info[1] > 4: + async def aget(self, key, **kw): + """Retrieve a cached value from the container""" + return await self._get_value(key, **kw).aget_value() + aget_value = aget + def remove_value(self, key, **kw): mycontainer = self._get_value(key, **kw) mycontainer.clear_value() @@ -547,28 +555,42 @@ def _cache_decorate(deco_args, manager, options, region): cache = [None] - def decorate(func): - namespace = util.func_namespace(func) - skip_self = util.has_self_arg(func) - signature = func_signature(func) + def _get_cache_region(region): + if region is None: + return None + if region not in cache_regions: + raise BeakerException( + 'Cache region not configured: %s' % region + ) + return cache_regions[region] + + def _short_circuit(cache, region): + if not cache and region is not None: + reg = _get_cache_region(region) + if not reg.get('enabled', True): + return True + return False + + def _find_cache(namespace, region, options): + if region is not None: + reg = _get_cache_region(region) + return Cache._get_cache(namespace, reg) + elif manager: + return manager.get_cache(namespace, **options) + else: + raise Exception("'manager + kwargs' or 'region' " + "argument is required") - @wraps(func) - def cached(*args, **kwargs): - if not cache[0]: - if region is not None: - if region not in cache_regions: - raise BeakerException( - 'Cache region not configured: %s' % region) - reg = cache_regions[region] - if not reg.get('enabled', True): - return func(*args, **kwargs) - cache[0] = Cache._get_cache(namespace, reg) - elif manager: - cache[0] = manager.get_cache(namespace, **options) - else: - raise Exception("'manager + kwargs' or 'region' " - "argument is required") + def _determine_key_length(region, options): + if region: + cachereg = cache_regions[region] + key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH) + else: + key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH) + return key_length + def _cache_key_func(namespace, skip_self, signature): + def _inner(key_length, *args, **kwargs): cache_key_kwargs = [] if kwargs: # kwargs provided, merge them in positional args @@ -582,23 +604,58 @@ def cached(*args, **kwargs): cache_key = u_(" ").join(map(u_, chain(deco_args, cache_key_args, cache_key_kwargs))) - if region: - cachereg = cache_regions[region] - key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH) - else: - key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH) - - # TODO: This is probably a bug as length is checked before converting to UTF8 + # TODO: This is probably a bug as length is checked before converting to UTF-8 # which will cause cache_key to grow in size. if len(cache_key) + len(namespace) > int(key_length): cache_key = sha1(cache_key.encode('utf-8')).hexdigest() + return cache_key + return _inner + + def decorate(func): + namespace = util.func_namespace(func) + skip_self = util.has_self_arg(func) + signature = func_signature(func) + + _determine_cache_key = _cache_key_func(namespace, skip_self, signature) + + async_func = inspect.iscoroutinefunction(func) + if async_func: + @wraps(func) + async def cached(*args, **kwargs): + if _short_circuit(cache[0], region): + return await func(*args, **kwargs) + + if not cache[0]: + cache[0] = _find_cache(namespace, region, options) + + key_length = _determine_key_length(region, options) + cache_key = _determine_cache_key(key_length, *args, **kwargs) + + async def go(): + return await func(*args, **kwargs) + # save org function name + go.__name__ = '_cached_%s' % (func.__name__,) + + return await cache[0].aget_value(cache_key, createfunc=go) + else: + @wraps(func) + def cached(*args, **kwargs): + if _short_circuit(cache[0], region): + return func(*args, **kwargs) + + if not cache[0]: + cache[0] = _find_cache(namespace, region, options) + + key_length = _determine_key_length(region, options) + cache_key = _determine_cache_key(key_length, *args, **kwargs) + + def go(): + return func(*args, **kwargs) + # save org function name + go.__name__ = '_cached_%s' % (func.__name__,) - def go(): - return func(*args, **kwargs) - # save org function name - go.__name__ = '_cached_%s' % (func.__name__,) + return cache[0].get_value(cache_key, createfunc=go) - return cache[0].get_value(cache_key, createfunc=go) cached._arg_namespace = namespace if region is not None: cached._arg_region = region diff --git a/beaker/container.py b/beaker/container.py index f3f5b4f..e39b588 100644 --- a/beaker/container.py +++ b/beaker/container.py @@ -6,6 +6,7 @@ import beaker.util as util import logging import os +import sys import time from beaker.exceptions import CreationAbortedError, MissingCacheParameter @@ -328,7 +329,7 @@ def _is_expired(self, storedtime, expiretime): ) ) - def get_value(self): + def _check_cache(self): self.namespace.acquire_read_lock() try: has_value = self.has_value() @@ -336,7 +337,7 @@ def get_value(self): try: stored, expired, value = self._get_value() if not self._is_expired(stored, expired): - return value + return None, value except KeyError: # guard against un-mutexed backends raising KeyError has_value = False @@ -345,36 +346,35 @@ def get_value(self): raise KeyError(self.key) finally: self.namespace.release_read_lock() + return has_value, None - has_createlock = False + def _creation_lock_or_value(self, has_value): creation_lock = self.namespace.get_creation_lock(self.key) if has_value: if not creation_lock.acquire(wait=False): debug("get_value returning old value while new one is created") - return value + return None, value else: debug("lock_creatfunc (didnt wait)") - has_createlock = True - - if not has_createlock: + else: debug("lock_createfunc (waiting)") creation_lock.acquire() debug("lock_createfunc (waited)") + return creation_lock, None + + def get_value(self): + has_value, value = self._check_cache() + if has_value is None: + return value + + creation_lock, value = self._creation_lock_or_value(has_value) + if creation_lock is None: + return value try: - # see if someone created the value already - self.namespace.acquire_read_lock() - try: - if self.has_value(): - try: - stored, expired, value = self._get_value() - if not self._is_expired(stored, expired): - return value - except KeyError: - # guard against un-mutexed backends raising KeyError - pass - finally: - self.namespace.release_read_lock() + has_value, value = self._check_cache() + if has_value is None: + return value debug("get_value creating new value") v = self.createfunc() @@ -384,6 +384,29 @@ def get_value(self): creation_lock.release() debug("released create lock") + if sys.version_info[0] == 3 and sys.version_info[1] > 4: + async def aget_value(self): + has_value, value = self._check_cache() + if has_value is None: + return value + + creation_lock, value = self._creation_lock_or_value(has_value) + if creation_lock is None: + return value + + try: + has_value, value = self._check_cache() + if has_value is None: + return value + + debug("get_value creating new value") + v = await self.createfunc() + self.set_value(v) + return v + finally: + creation_lock.release() + debug("released create lock") + def _get_value(self): value = self.namespace[self.key] try: