Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Caching async functions #206

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 89 additions & 32 deletions beaker/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
:func:`.region_invalidate`.

"""
import inspect
import warnings
import sys
from itertools import chain

from beaker._compat import u_, unicode_text, func_signature, bindfuncargs
Expand Down Expand Up @@ -322,6 +324,12 @@ def get(self, key, **kw):
return self._get_value(key, **kw).get_value()
get_value = get

if sys.version_info[0] == 3 and sys.version_info[1] > 4:
async def aget(self, key, **kw):
"""Retrieve a cached value from the container"""
return await self._get_value(key, **kw).aget_value()
aget_value = aget

def remove_value(self, key, **kw):
mycontainer = self._get_value(key, **kw)
mycontainer.clear_value()
Expand Down Expand Up @@ -547,28 +555,42 @@ def _cache_decorate(deco_args, manager, options, region):

cache = [None]

def decorate(func):
namespace = util.func_namespace(func)
skip_self = util.has_self_arg(func)
signature = func_signature(func)
def _get_cache_region(region):
if region is None:
return None
if region not in cache_regions:
raise BeakerException(
'Cache region not configured: %s' % region
)
return cache_regions[region]

def _short_circuit(cache, region):
if not cache and region is not None:
reg = _get_cache_region(region)
if not reg.get('enabled', True):
return True
return False

def _find_cache(namespace, region, options):
if region is not None:
reg = _get_cache_region(region)
return Cache._get_cache(namespace, reg)
elif manager:
return manager.get_cache(namespace, **options)
else:
raise Exception("'manager + kwargs' or 'region' "
"argument is required")

@wraps(func)
def cached(*args, **kwargs):
if not cache[0]:
if region is not None:
if region not in cache_regions:
raise BeakerException(
'Cache region not configured: %s' % region)
reg = cache_regions[region]
if not reg.get('enabled', True):
return func(*args, **kwargs)
cache[0] = Cache._get_cache(namespace, reg)
elif manager:
cache[0] = manager.get_cache(namespace, **options)
else:
raise Exception("'manager + kwargs' or 'region' "
"argument is required")
def _determine_key_length(region, options):
if region:
cachereg = cache_regions[region]
key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
else:
key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
return key_length

def _cache_key_func(namespace, skip_self, signature):
def _inner(key_length, *args, **kwargs):
cache_key_kwargs = []
if kwargs:
# kwargs provided, merge them in positional args
Expand All @@ -582,23 +604,58 @@ def cached(*args, **kwargs):

cache_key = u_(" ").join(map(u_, chain(deco_args, cache_key_args, cache_key_kwargs)))

if region:
cachereg = cache_regions[region]
key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
else:
key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH)

# TODO: This is probably a bug as length is checked before converting to UTF8
# TODO: This is probably a bug as length is checked before converting to UTF-8
# which will cause cache_key to grow in size.
if len(cache_key) + len(namespace) > int(key_length):
cache_key = sha1(cache_key.encode('utf-8')).hexdigest()
return cache_key
return _inner

def decorate(func):
namespace = util.func_namespace(func)
skip_self = util.has_self_arg(func)
signature = func_signature(func)

_determine_cache_key = _cache_key_func(namespace, skip_self, signature)

async_func = inspect.iscoroutinefunction(func)
if async_func:
@wraps(func)
async def cached(*args, **kwargs):
if _short_circuit(cache[0], region):
return await func(*args, **kwargs)

if not cache[0]:
cache[0] = _find_cache(namespace, region, options)

key_length = _determine_key_length(region, options)
cache_key = _determine_cache_key(key_length, *args, **kwargs)

async def go():
return await func(*args, **kwargs)
# save org function name
go.__name__ = '_cached_%s' % (func.__name__,)

return await cache[0].aget_value(cache_key, createfunc=go)
else:
@wraps(func)
def cached(*args, **kwargs):
if _short_circuit(cache[0], region):
return func(*args, **kwargs)

if not cache[0]:
cache[0] = _find_cache(namespace, region, options)

key_length = _determine_key_length(region, options)
cache_key = _determine_cache_key(key_length, *args, **kwargs)

def go():
return func(*args, **kwargs)
# save org function name
go.__name__ = '_cached_%s' % (func.__name__,)

def go():
return func(*args, **kwargs)
# save org function name
go.__name__ = '_cached_%s' % (func.__name__,)
return cache[0].get_value(cache_key, createfunc=go)

return cache[0].get_value(cache_key, createfunc=go)
cached._arg_namespace = namespace
if region is not None:
cached._arg_region = region
Expand Down
63 changes: 43 additions & 20 deletions beaker/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import beaker.util as util
import logging
import os
import sys
import time

from beaker.exceptions import CreationAbortedError, MissingCacheParameter
Expand Down Expand Up @@ -328,15 +329,15 @@ def _is_expired(self, storedtime, expiretime):
)
)

def get_value(self):
def _check_cache(self):
self.namespace.acquire_read_lock()
try:
has_value = self.has_value()
if has_value:
try:
stored, expired, value = self._get_value()
if not self._is_expired(stored, expired):
return value
return None, value
except KeyError:
# guard against un-mutexed backends raising KeyError
has_value = False
Expand All @@ -345,36 +346,35 @@ def get_value(self):
raise KeyError(self.key)
finally:
self.namespace.release_read_lock()
return has_value, None

has_createlock = False
def _creation_lock_or_value(self, has_value):
creation_lock = self.namespace.get_creation_lock(self.key)
if has_value:
if not creation_lock.acquire(wait=False):
debug("get_value returning old value while new one is created")
return value
return None, value
else:
debug("lock_creatfunc (didnt wait)")
has_createlock = True

if not has_createlock:
else:
debug("lock_createfunc (waiting)")
creation_lock.acquire()
debug("lock_createfunc (waited)")
return creation_lock, None

def get_value(self):
has_value, value = self._check_cache()
if has_value is None:
return value

creation_lock, value = self._creation_lock_or_value(has_value)
if creation_lock is None:
return value

try:
# see if someone created the value already
self.namespace.acquire_read_lock()
try:
if self.has_value():
try:
stored, expired, value = self._get_value()
if not self._is_expired(stored, expired):
return value
except KeyError:
# guard against un-mutexed backends raising KeyError
pass
finally:
self.namespace.release_read_lock()
has_value, value = self._check_cache()
if has_value is None:
return value

debug("get_value creating new value")
v = self.createfunc()
Expand All @@ -384,6 +384,29 @@ def get_value(self):
creation_lock.release()
debug("released create lock")

if sys.version_info[0] == 3 and sys.version_info[1] > 4:
async def aget_value(self):
has_value, value = self._check_cache()
if has_value is None:
return value

creation_lock, value = self._creation_lock_or_value(has_value)
if creation_lock is None:
return value

try:
has_value, value = self._check_cache()
if has_value is None:
return value

debug("get_value creating new value")
v = await self.createfunc()
self.set_value(v)
return v
finally:
creation_lock.release()
debug("released create lock")

def _get_value(self):
value = self.namespace[self.key]
try:
Expand Down