feat: cache aggregates by owner results until they are updated
Psycojoker committed Jan 31, 2025
1 parent 827832f commit 6f3b217
Showing 2 changed files with 50 additions and 2 deletions.
28 changes: 28 additions & 0 deletions src/aleph/cache.py
@@ -0,0 +1,28 @@
from collections import defaultdict


class Cache:
    def __init__(self):
        self._cache = defaultdict(dict)

    def get(self, key, namespace):
        return self._cache[namespace].get(key)

    def set(self, key, value, namespace):
        self._cache[namespace][key] = value

    def exists(self, key, namespace):
        return key in self._cache[namespace]

    def delete_namespace(self, namespace):
        if namespace in self._cache:
            self._cache[namespace] = {}

    def delete(self, key, namespace):
        if self.exists(key, namespace):
            del self._cache[namespace][key]


# Simple in-memory cache. We can't use aiocache here because most of the
# ORM methods are not async-compatible.
cache = Cache()
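
A minimal usage sketch of the new Cache class (not part of the commit; the "aggregates_by_owner:0xabc" namespace and the cached values are made up): entries live in per-namespace dicts, and delete_namespace drops a whole namespace at once, which is what the invalidation hook in aggregates.py below relies on.

    from aleph.cache import cache

    # Store a query result under a per-owner namespace.
    cache.set("False None", ["agg1", "agg2"], namespace="aggregates_by_owner:0xabc")
    assert cache.get("False None", namespace="aggregates_by_owner:0xabc") == ["agg1", "agg2"]

    # get() returns None on a miss instead of raising.
    assert cache.get("other key", namespace="aggregates_by_owner:0xabc") is None

    # Invalidate every cached result for this owner in one call.
    cache.delete_namespace("aggregates_by_owner:0xabc")
    assert cache.get("False None", namespace="aggregates_by_owner:0xabc") is None
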
24 changes: 22 additions & 2 deletions src/aleph/db/accessors/aggregates.py
@@ -1,4 +1,5 @@
import datetime as dt
import logging
from typing import (
    Any,
    Dict,
@@ -11,13 +12,22 @@
    overload,
)

-from sqlalchemy import delete, func, literal_column, select, update
+from sqlalchemy import delete, event, func, literal_column, select, update
from sqlalchemy.dialects.postgresql import aggregate_order_by, insert
from sqlalchemy.orm import defer, selectinload

from aleph.cache import cache
from aleph.db.models import AggregateDb, AggregateElementDb
from aleph.types.db_session import DbSession

logger = logging.getLogger(__name__)


@event.listens_for(AggregateDb, "after_update", propagate=True)
@event.listens_for(AggregateDb, "after_delete", propagate=True)
def prune_cache_for_updated_aggregates(mapper, connection, target):
    cache.delete_namespace(f"aggregates_by_owner:{target.owner}")


def aggregate_exists(session: DbSession, key: str, owner: str) -> bool:
    return AggregateDb.exists(
@@ -55,6 +65,14 @@ def get_aggregates_by_owner(


def get_aggregates_by_owner(session, owner, with_info, keys=None):
    cache_key = f"{with_info} {keys}"

    if (
        aggregates := cache.get(cache_key, namespace=f"aggregates_by_owner:{owner}")
    ) is not None:
        logger.debug(f"cache hit for aggregates_by_owner on cache key {cache_key}")
        return aggregates

    where_clause = AggregateDb.owner == owner
    if keys:
        where_clause = where_clause & AggregateDb.key.in_(keys)
@@ -80,7 +98,9 @@ def get_aggregates_by_owner(session, owner, with_info, keys=None):
        .filter(where_clause)
        .order_by(AggregateDb.key)
    )
-    return query.all()
+    result = query.all()
+    cache.set(cache_key, result, namespace=f"aggregates_by_owner:{owner}")
+    return result


def get_aggregate_by_key(
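
Taken together, the two files implement a read-through cache with event-driven invalidation. A sketch of the intended flow (the session object and the owner address "0xabc" are hypothetical): the first call queries the database and caches the result under a key derived from the arguments; identical calls are then served from memory until an AggregateDb row for that owner is updated or deleted, at which point the after_update/after_delete listeners empty the owner's namespace and the next call falls through to the database again.

    # First call: cache miss, runs the SQL query and caches the rows.
    aggregates = get_aggregates_by_owner(session, "0xabc", with_info=False)

    # Same arguments again: served from the in-memory cache, no SQL issued.
    aggregates = get_aggregates_by_owner(session, "0xabc", with_info=False)

    # Any flushed update or delete of an AggregateDb row owned by "0xabc"
    # triggers prune_cache_for_updated_aggregates, which calls
    # cache.delete_namespace("aggregates_by_owner:0xabc"); the next call
    # re-queries the database and repopulates the cache.
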
